<a href="https://colab.research.google.com/github/danielsaggau/IR_LDC/blob/main/Rezaei_Bregman.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn

class BregmanLoss(nn.Module):
    """Contrastive cross-entropy loss over a Bregman-style similarity matrix.

    The two views are stacked row-wise into one 2-D ``features`` tensor.  For
    every pair of rows ``(i, j)`` the distance is
    ``max(features[i]) - features[i, argmax(features[j])]`` and is turned into
    a similarity by the kernel selected with ``case``.  Each row's positive is
    the same sample in the other view; all remaining rows (except itself) are
    negatives.

    Args:
        batch_size: Nominal number of samples per view.  Only used to
            pre-build the negatives mask; smaller or larger batches are
            handled at call time.
        temperature: Divisor applied to the similarity matrix before the
            cross-entropy.
        sigma: Bandwidth of the Gaussian kernel (``case == 2``).
        case: Kernel used to map distances to similarities.
            0: ``1 / (d / max(d) + 1)``, 1: ``exp(-d)``,
            2: ``exp(-d / (2 * sigma**2))`` (default), 3: ``1 - d``.
    """

    def __init__(self, batch_size, temperature, sigma, case=2):
        super(BregmanLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.sigma = sigma
        self.case = case

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, batch_size):
        """Return a ``(2B, 2B)`` boolean mask that is True only for negatives.

        The diagonal (a sample with itself) and the two off-diagonals at
        offset ``+-batch_size`` (a sample with its other view) are False.
        """
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def b_sim(self, features):
        """Compute the pairwise similarity matrix and per-column argmax counts.

        Args:
            features: 2-D tensor, one row per sample.

        Returns:
            ``(sim_matrix, num_max)`` where ``sim_matrix`` is square with one
            row/column per input row, and ``num_max[k]`` counts how many rows
            have their maximum in column ``k``.

        Raises:
            ValueError: If ``self.case`` is not one of 0, 1, 2, 3.
        """
        max_features, indx_max_features = torch.max(features, dim=1)
        max_features = max_features.reshape(-1, 1)

        # Count, per column, how many rows select it as their argmax.
        # The identity matrix must live on the same device as the index
        # tensor, otherwise indexing fails for CUDA inputs.
        eye = torch.eye(features.shape[1], device=features.device)
        num_max = torch.sum(eye[indx_max_features], dim=0)

        # dist[i, j] = max(features[i]) - features[i, argmax(features[j])]
        dist_matrix = max_features - features[:, indx_max_features]

        if self.case == 0:
            m2 = torch.divide(dist_matrix, torch.max(dist_matrix))
            one = torch.tensor([1], device=features.device)
            sim_matrix = torch.divide(one, m2 + 1)
        elif self.case == 1:
            gamma = torch.tensor([1], device=features.device)
            sim_matrix = torch.exp(torch.mul(-dist_matrix, gamma))
        elif self.case == 2:
            sigma = torch.tensor([self.sigma], device=features.device)
            sig2 = 2 * torch.pow(sigma, 2)
            sim_matrix = torch.exp(torch.div(-dist_matrix, sig2))
        elif self.case == 3:
            sim_matrix = 1 - dist_matrix
        else:
            raise ValueError(
                "case must be one of 0, 1, 2, 3, got {!r}".format(self.case))

        return sim_matrix, num_max

    def forward(self, out_a, out_b):
        """Return ``(loss, num_max)`` for two aligned batches of outputs.

        Args:
            out_a: 2-D outputs of the first view, one row per sample.
            out_b: 2-D outputs of the second view; row ``i`` must correspond
                to row ``i`` of ``out_a``.

        Returns:
            ``loss`` is the summed cross-entropy divided by the number of
            stacked rows; ``num_max`` is forwarded from :meth:`b_sim`.

        Raises:
            ValueError: If the two views have a different number of rows.
        """
        # Use the actual batch size so that a smaller final batch works.
        batch_size = out_a.shape[0]
        if out_b.shape[0] != batch_size:
            raise ValueError(
                "out_a and out_b must have the same number of rows, got "
                "{} and {}".format(batch_size, out_b.shape[0]))
        N = 2 * batch_size

        features = torch.cat((out_a, out_b), dim=0)

        sim_matrix, num_max = self.b_sim(features)
        sim_matrix = sim_matrix / self.temperature

        pos_ab = torch.diag(sim_matrix, batch_size)
        pos_ba = torch.diag(sim_matrix, -batch_size)
        positives = torch.cat((pos_ab, pos_ba), dim=0).reshape(N, 1)

        if batch_size == self.batch_size:
            # Cache the pre-built mask on the device actually in use.
            if self.mask.device != sim_matrix.device:
                self.mask = self.mask.to(sim_matrix.device)
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size).to(sim_matrix.device)
        # Every row keeps exactly N - 2 negatives (all but itself and its pair).
        negatives = sim_matrix[mask].reshape(N, N - 2)

        # The positive sits in column 0 of the logits, hence all-zero labels.
        labels = torch.zeros(N, dtype=torch.long, device=features.device)
        logits = torch.cat((positives, negatives), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss, num_max

In [None]:
import torch
import torch.nn as nn


class NT_Xent(nn.Module):
    """Normalized temperature-scaled cross-entropy loss on cosine similarity.

    The two views are stacked row-wise; each row's positive is the same
    sample in the other view and every other row (except itself) is a
    negative.

    Args:
        batch_size: Nominal number of samples per view.  Only used to
            pre-build the negatives mask; other batch sizes are handled at
            call time.
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, batch_size, temperature):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        """Return a ``(2B, 2B)`` boolean mask that is True only for negatives.

        The diagonal (a sample with itself) and the two off-diagonals at
        offset ``+-batch_size`` (a sample with its other view) are False.
        """
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def forward(self, out_a, out_b):
        """Return the scalar loss for two aligned batches of embeddings.

        Args:
            out_a: 2-D embeddings of the first view, one row per sample.
            out_b: 2-D embeddings of the second view; row ``i`` must
                correspond to row ``i`` of ``out_a``.

        Returns:
            Summed cross-entropy divided by the number of stacked rows.

        Raises:
            ValueError: If the two views have a different number of rows.
        """
        # Use the actual batch size so that a smaller final batch works.
        batch_size = out_a.shape[0]
        if out_b.shape[0] != batch_size:
            raise ValueError(
                "out_a and out_b must have the same number of rows, got "
                "{} and {}".format(batch_size, out_b.shape[0]))
        N = 2 * batch_size

        out = torch.cat((out_a, out_b), dim=0)

        # Broadcasting (N, 1, D) against (1, N, D) yields all pairwise
        # cosine similarities as an (N, N) matrix.
        sim_matrix = self.similarity_f(out.unsqueeze(1), out.unsqueeze(0)) / self.temperature

        pos_ab = torch.diag(sim_matrix, batch_size)
        pos_ba = torch.diag(sim_matrix, -batch_size)
        positives = torch.cat((pos_ab, pos_ba), dim=0).reshape(N, 1)

        if batch_size == self.batch_size:
            # Cache the pre-built mask on the device actually in use.
            if self.mask.device != sim_matrix.device:
                self.mask = self.mask.to(sim_matrix.device)
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size).to(sim_matrix.device)
        # Every row keeps exactly N - 2 negatives (all but itself and its pair).
        negatives = sim_matrix[mask].reshape(N, N - 2)

        # The positive sits in column 0 of the logits, hence all-zero labels.
        labels = torch.zeros(N, dtype=torch.long, device=out.device)
        logits = torch.cat((positives, negatives), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss

In [4]:
import pandas as pd
import torch
from tqdm import tqdm
#from loss.breg_loss import BregmanLoss
from loss.nt_xent import NT_Xent
#from loss.breg_margin_loss import BregMarginLoss


class bcolors:
    """ANSI escape sequences used to colorize progress-bar text."""

    # Reset and text attributes.
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Bright foreground colors, in ascending SGR code order.
    FAIL = '\033[91m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    OKBLUE = '\033[94m'
    HEADER = '\033[95m'
    OKCYAN = '\033[96m'

class Trainer():
    """Runs training epochs and weighted-kNN evaluation for a two-output model.

    The wrapped ``model`` is called as ``feature, out = model(x)``.  During
    training the Bregman loss is computed on ``out`` and, when
    ``self.mixed_loss`` is True, ``lmbda`` times the NT-Xent loss on
    ``feature`` is added.

    Args:
        model: Network returning a ``(feature, out)`` pair.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per epoch.
        temperature: Temperature passed to both losses and used to sharpen
            the kNN weights in :meth:`test`.
        num_cls: Number of classes used for the kNN vote.
        epochs: Total epoch count (only used in progress-bar text).
        sigma: Kernel bandwidth passed to ``BregmanLoss``.
        lmbda: Weight of the NT-Xent term in the mixed loss.
        device: Device the input batches are moved to.
    """

    def __init__(self, model,
                 optimizer,
                 scheduler,
                 temperature,
                 num_cls,
                 epochs,
                 sigma,
                 lmbda,
                 device):
        super(Trainer, self).__init__()
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.temperature = temperature
        self.num_cls = num_cls
        self.epochs = epochs
        self.sigma = sigma
        self.lmbda = lmbda
        self.device = device
        self.mixed_loss = True

    def train(self, data_loader, epoch):
        """Train for one epoch.

        Args:
            data_loader: Yields ``([aug_1, aug_2], target)`` batches.
            epoch: Current epoch number (only used in progress-bar text).

        Returns:
            ``(mean total loss, mean Bregman loss, mean NT-Xent loss,
            last learning rate)``; the means are per sample.
        """
        self.model.train()
        batch_size = data_loader.batch_size
        bloss = BregmanLoss(batch_size, self.temperature, self.sigma)
        nt_xent = NT_Xent(batch_size, self.temperature)

        total_loss, total_num, tot_max, train_bar = 0.0, 0, 0, tqdm(data_loader)
        tot_bloss, tot_nt_xent = 0.0, 0.0
        for [aug_1, aug_2], target in train_bar:
            aug_1, aug_2 = aug_1.to(self.device), aug_2.to(self.device)
            feature_1, out_1 = self.model(aug_1)
            feature_2, out_2 = self.model(aug_2)

            # Weight the running sums by the real number of samples in this
            # batch; the loader's nominal batch_size over-counts a smaller
            # final batch.
            n_samples = aug_1.size(0)

            # compute loss
            loss, num_max = bloss(out_1, out_2)
            tot_bloss += loss.item() * n_samples
            if self.mixed_loss:
                loss1 = nt_xent(feature_1, feature_2)
                tot_nt_xent += loss1.item() * n_samples
                loss = loss + self.lmbda * loss1

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # tot_max starts as the int 0 and becomes a tensor after the
            # first addition.
            tot_max += num_max
            total_num += n_samples
            total_loss += loss.item() * n_samples
            train_bar.set_description(
                '{}Train{} {}Epoch:{} [{}/{}] {}Loss:{}  {:.4f} {}Active Subs:{} [{}/{}]'
                .format(
                    bcolors.OKCYAN, bcolors.ENDC,
                    bcolors.WARNING, bcolors.ENDC,
                    epoch,
                    self.epochs,
                    bcolors.WARNING, bcolors.ENDC,
                    total_loss / total_num,
                    bcolors.WARNING, bcolors.ENDC,
                    len(torch.where(tot_max>10)[0]),
                    tot_max.shape[0]))

        # The scheduler is stepped every epoch (no warm-up gating).
        self.scheduler.step()

        return (total_loss/total_num,
                tot_bloss/total_num,
                tot_nt_xent/total_num,
                self.scheduler.get_last_lr()[0])

    def bregman_sim(self, feature, feature_bank):
        """Return the ``[B, N]`` similarity of ``feature`` rows to bank rows.

        Entry ``(i, j)`` is ``exp(-d / 2)`` with
        ``d = max(feature[i]) - feature[i, argmax(feature_bank[j])]``.
        """
        # NOTE(review): the kernel bandwidth is fixed to 1 here while training
        # uses self.sigma -- confirm whether that is intentional.
        # [B, 1]
        mf = torch.max(feature, dim=1)
        # [N, 1]
        mfb = torch.max(feature_bank, dim=1)
        indx_max_feature_bank = mfb[1]
        max_feature = mf[0].reshape(-1, 1)
        # [B, N]
        dist_matrix = max_feature - feature[:, indx_max_feature_bank]
        # Computing Similarity from Bregman distance
        sigma = torch.tensor([1.]).to(self.device)
        sigma = 2 * torch.pow(sigma, 2)
        sim_matrix = torch.exp(torch.div(-dist_matrix, sigma))

        return sim_matrix

    def test(self, memory_data_loader, test_data_loader, k_nn, epoch):
        """Evaluate with a weighted k-nearest-neighbour vote.

        A bank of model outputs is built from ``memory_data_loader``; each
        test sample is then labelled by its ``k_nn`` most similar bank
        entries, weighted by ``exp(similarity / temperature)``.

        Args:
            memory_data_loader: Yields ``([data, _], target)`` batches used
                to build the bank and its labels.
            test_data_loader: Yields ``([data, _], target)`` batches to score.
            k_nn: Number of neighbours used in the vote.
            epoch: Current epoch number (only used in progress-bar text).

        Returns:
            ``(top-1 accuracy, top-5 accuracy)`` in percent.
        """
        self.model.eval()
        total_top1, total_top5, total_num, feature_bank, feature_labels = 0.0, 0.0, 0, [], []

        with torch.no_grad():
            # generate feature bank
            for [data, _], target in tqdm(memory_data_loader,
                                        desc=f'{bcolors.OKBLUE}Feature extracting{bcolors.ENDC}'):
                feature, out = self.model(data.to(self.device))
                feature_bank.append(out)
                feature_labels.append(target)
            # [N, D]
            feature_bank = torch.cat(feature_bank, dim=0)
            # [N]
            feature_labels = torch.cat(feature_labels, dim=0).long().to(self.device)

            # loop test data to predict the label by weighted knn search
            test_bar = tqdm(test_data_loader)
            for [data, _], target in test_bar:
                data, target = data.to(self.device), target.to(self.device)
                feature, out = self.model(data)

                total_num += data.size(0)
                # compute bregman similarity between each feature vector and feature bank ---> [B, N]
                sim_matrix = self.bregman_sim(out, feature_bank)
                # [B, K]
                sim_weight, sim_indices = sim_matrix.topk(k=k_nn, dim=-1)
                # [B, K]
                sim_labels = torch.gather(feature_labels.expand(data.size(0), -1),
                                          dim=-1,
                                          index=sim_indices)
                sim_weight = (sim_weight / self.temperature).exp()

                # counts for each class
                one_hot_label = torch.zeros(data.size(0) * k_nn, self.num_cls, device=self.device)
                # [B*K, C]
                one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
                # weighted score ---> [B, C]
                pred_scores = torch.sum(one_hot_label.view(
                    data.size(0), -1, self.num_cls) * sim_weight.unsqueeze(dim=-1), dim=1)

                pred_labels = pred_scores.argsort(dim=-1, descending=True)
                total_top1 += torch.sum(
                    (pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
                total_top5 += torch.sum(
                    (pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()

                test_bar.set_description(
                    '{}Test{}  {}Epoch:{} [{}/{}] {}Acc@1: {}{:.2f}% {}Acc@5: {}{:.2f}%'.format(
                    bcolors.OKCYAN, bcolors.ENDC,
                    bcolors.WARNING, bcolors.ENDC,
                    epoch,
                    self.epochs,
                    bcolors.WARNING, bcolors.ENDC,
                    (total_top1 / total_num) * 100,
                    bcolors.WARNING, bcolors.ENDC,
                    (total_top5 / total_num) * 100))

        return (total_top1 / total_num) * 100, (total_top5 / total_num) * 100

__main__.BregmanLoss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet18, resnet50


class Model(nn.Module):
    """ResNet backbone followed by ``k_subs`` small MLP sub-networks.

    The backbone's final ``fc`` layer is replaced by
    ``Linear -> ReLU -> original fc`` so it emits ``fc_dim`` features.  Every
    sub-network consumes those features; their outputs are concatenated along
    the last dimension.

    Args:
        base_model: ``"resnet18"`` or ``"resnet50"``; any other key raises
            ``KeyError``.
        fc_dim: Output width of the backbone head and input width of every
            sub-network.
        k_subs: Number of sub-networks.
        layer_sizes: Output width of each sub-network layer, last entry being
            the sub-network's output width.  Defaults to ``[64, 1]``.
        use_bn: Insert ``BatchNorm1d`` after every hidden linear layer.
        dr_rate: Dropout probability after every hidden layer.
    """

    def __init__(self,
                 base_model="resnet18",
                 fc_dim=128,
                 k_subs=10,
                 layer_sizes=None,
                 use_bn=False,
                 dr_rate=0.2):
        super(Model, self).__init__()

        # Avoid a shared mutable default; also accept any sequence of ints.
        if layer_sizes is None:
            layer_sizes = [64, 1]
        layer_sizes = list(layer_sizes)

        # Map names to constructors so only the requested backbone is built.
        backbones = {"resnet18": resnet18,
                     "resnet50": resnet50}
        self.model = backbones[base_model](num_classes=fc_dim)
        dim_mlp = self.model.fc.in_features
        self.model.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.model.fc)

        # k subnetworks for bregman
        self.subnets = nn.ModuleList()
        in_sizes = [fc_dim] + layer_sizes[:-1]
        last_idx = len(layer_sizes) - 1

        for k_idx in range(k_subs):
            fc = nn.Sequential()

            for i, (in_size, out_size) in enumerate(zip(in_sizes, layer_sizes)):
                if i < last_idx:
                    # Hidden layer: Linear -> (BatchNorm) -> ReLU -> Dropout.
                    fc.add_module(
                        name="fc_{:d}_{:d}".format(k_idx, i),
                        module=nn.Linear(in_size, out_size))

                    if use_bn:
                        fc.add_module(
                            name="bn_{:d}_{:d}".format(k_idx, i),
                            module=nn.BatchNorm1d(out_size))

                    fc.add_module(
                        name="relu_{:d}_{:d}".format(k_idx, i),
                        module=nn.ReLU())

                    fc.add_module(
                        name="dp_{:d}_{:d}".format(k_idx, i),
                        module=nn.Dropout(p=dr_rate))
                else:
                    # Output layer: plain linear projection, no activation.
                    fc.add_module(
                        name="output_{:d}".format(k_idx),
                        module=nn.Linear(in_size, out_size))

            self.subnets.append(fc)

    def forward(self, x):
        """Return ``(fc_out, out)``.

        ``fc_out`` is the backbone output; ``out`` is the concatenation of
        every sub-network's output along the last dimension.
        """
        fc_out = self.model(x)
        out = torch.cat([subnet(fc_out) for subnet in self.subnets], -1)
        return fc_out, out


In [None]:
import argparse
import os
import yaml
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt
import seaborn as sns
sns.set_theme(style="darkgrid")


from data_aug.data_loader import CustomDataLoader
from model import Model
from trainer import Trainer

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"


def save_config_file(model_checkpoints_folder, args):
    """Dump ``args`` as YAML to ``<model_checkpoints_folder>/config.yml``.

    The folder is created if it does not exist yet.

    Args:
        model_checkpoints_folder: Directory that receives ``config.yml``.
        args: Object serialized with ``yaml.dump``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_checkpoints_folder, exist_ok=True)
    with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)


if __name__ == '__main__':
    # Command-line entry point: parse options, build data/model/optimizer,
    # then train and checkpoint once per epoch.
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--fc_dim', default=128, type=int, help='Feature dim for latent vector')
    parser.add_argument('--temperature', default=0.1, type=float, help='Temperature used in softmax')
    parser.add_argument('--lmbda', default=5, type=float, help='ratio of contrastive to divergence loss')
    parser.add_argument('--sigma', default=1.5, type=float, help='sigma in gaussian kernel')
    # k_nn is read below (file names, optional evaluation), so it must exist.
    parser.add_argument('--k_nn', default=200, type=int, help='k in knn')
    parser.add_argument('--batch_size', default=256, type=int, help='batch size')
    parser.add_argument('--epochs', default=400, type=int, help='epochs')
    parser.add_argument('--k_subs', default=500, type=int, help='k subnets')
    # nargs='+' keeps the value a list when it is given on the command line.
    parser.add_argument('--layer_size', default=[32, 1], type=int, nargs='+',
                        help='subnetworks layers size (default: [32, 1])')
    parser.add_argument('--lr', default=3e-3, type=float,help='initial learning rate')
    parser.add_argument('--wd', default=1e-4, type=float, help='weight decay (default: 1e-4)')
    parser.add_argument('--seed', default=10, type=int, help='seed for initializing training.')
    parser.add_argument('--workers', default=16, type=int, help='number of data loading workers')
    parser.add_argument('--base_model',
                        default='resnet50',
                        help='backbone architecture',
                        choices=["resnet18", "resnet50"])

    # The single-dash spelling is kept as an alias for existing command lines.
    parser.add_argument('--dataset-name', '-dataset-name', default='imagenet', help='dataset name')
    parser.add_argument('--stepwise', action='store_true', help='use stepwise lr schedule')
    # args parse
    args = parser.parse_args()
    base_model = args.base_model
    dataset_name = args.dataset_name
    lr, wd = args.lr, args.wd
    fc_dim, temperature, k_nn = args.fc_dim, args.temperature, args.k_nn
    batch_size, epochs = args.batch_size, args.epochs
    workers = args.workers

    # Apply the seed that the CLI advertises.
    torch.manual_seed(args.seed)

    # model setup and optimizer config
    if torch.cuda.is_available():
        args.device = torch.device('cuda')
        #cudnn.deterministic = False
        #cudnn.benchmark = True
    else:
        args.device = torch.device('cpu')

    # create a tensorboard writer
    writer = SummaryWriter()
    # save config file
    save_config_file(writer.log_dir, args)

    ############################################################
    ### Load Datasets and Dataloaders
    dl = CustomDataLoader()
    train_loader, memory_loader, test_loader = dl.get_loader(dataset_name, batch_size, workers)

    num_cls = len(test_loader.dataset.classes)
    model = Model(base_model=base_model,
                  fc_dim=fc_dim,
                  k_subs=args.k_subs,
                  layer_sizes=args.layer_size,
                  use_bn=True,
                  dr_rate=0.2).to(args.device)
    print(model)
    if torch.cuda.device_count() > 1:
        print("We have available", torch.cuda.device_count(), "GPUs!")
        # Use the GPUs that are actually visible instead of assuming eight.
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    if args.stepwise:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[120,240], gamma=0.1)
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=epochs,
            eta_min=0,
            last_epoch=-1)

    trainer = Trainer(model,
                      optimizer,
                      scheduler,
                      temperature,
                      num_cls,
                      epochs,
                      args.sigma,
                      args.lmbda,
                      args.device)

    # training loop
    results = {'train_loss': [],
               'bloss_loss': [],
               'NTXent_loss': [],
               #'test_acc@1': [],
               #'test_acc@5': []
              }
    save_name_pre = '{}_K{}_{}_{}_{}_{}_{}_{}_{}'.format(
        dataset_name, args.k_subs,
        base_model, lr,
        fc_dim, temperature,
        k_nn, batch_size, epochs)
    csv_dir = os.path.join(writer.log_dir, '{}_stats.csv'.format(save_name_pre))
    model_dir = os.path.join(writer.log_dir, '{}_model.pth'.format(save_name_pre))
    fig_dir = os.path.join(writer.log_dir, '{}_loss_acc.png'.format(save_name_pre))

    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        # Trainer.train returns four values; the last one is the current LR.
        train_loss, bloss, NTXent, current_lr = trainer.train(train_loader, epoch)
        results['train_loss'].append(train_loss)
        results['bloss_loss'].append(bloss)
        results['NTXent_loss'].append(NTXent)
        writer.add_scalar('loss/train', results['train_loss'][-1], epoch)
        writer.add_scalar('lr', current_lr, epoch)

        #test_acc_1, test_acc_5 = trainer.test(memory_loader, test_loader, k_nn, epoch)
        #results['test_acc@1'].append(test_acc_1)
        #results['test_acc@5'].append(test_acc_5)
        #writer.add_scalar('acc@1/test', results['test_acc@1'][-1], epoch)
        #writer.add_scalar('acc@5/test', results['test_acc@5'][-1], epoch)

        # save statistics
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv(csv_dir, index_label='epoch')

        # Save the unwrapped weights so the checkpoint loads without DataParallel.
        if isinstance(model, nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save(state_dict, model_dir)

        #if test_acc_1 > best_acc:
        #    best_acc = test_acc_1
        #    if isinstance(model, nn.DataParallel):
        #        state_dict = model.module.state_dict()
        #    else:
        #        state_dict = model.state_dict()
        #    torch.save(state_dict, model_dir)

    # Flush pending TensorBoard events before exiting.
    writer.close()

    # plotting loss and accuracies
    #df = pd.read_csv(csv_dir)
    #fig, axes = plt.subplots(1, 3, sharex=True, figsize=(20,5))
    #axes[0].set_title('Loss/Train')
    #axes[1].set_title('acc@1/test')
    #axes[2].set_title('acc@5/test')
    #sns.lineplot(ax=axes[0], x="epoch", y="train_loss", data=df)
    #sns.lineplot(ax=axes[1], x="epoch", y="test_acc@1", data=df)
    #sns.lineplot(ax=axes[2], x="epoch", y="test_acc@5", data=df)

    #fig.savefig(fig_dir)
    
    
