In [1]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm.auto import tqdm
import torch.nn.functional as F
import Models.Network as models
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Loading Dataset

dataroot = r"/data4/home/hrishikeshj/DLNLP/ADRLA2/Dataset/train"
image_size = 128
batch_size = 64

dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=8, drop_last=True)

In [3]:
dim_of_emb = 128
device_to_use = 1

In [4]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class projection_MLP(nn.Module):
    def __init__(self):
        super(projection_MLP, self).__init__()

        n_channels = 512

        self.projection_head = nn.Sequential()

        self.projection_head.add_module('W1', nn.Linear(
            n_channels, n_channels))
        self.projection_head.add_module('ReLU', nn.ReLU())
        self.projection_head.add_module('W2', nn.Linear(
            n_channels, 128))

    def forward(self, x):
        return self.projection_head(x)

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=128, zero_init_residual=False,
                 groups=1, width_per_group=64):
        super(ResNet, self).__init__()
        
        self._norm_layer = nn.BatchNorm2d

        self.register_buffer

        self.inplanes = 64
        self.dilation = 1
        
        self.groups = groups
        self.base_width = width_per_group

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                   bias=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

            

        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(512 * block.expansion,num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # if isinstance(m, Bottleneck):
                #     nn.init.constant_(m.bn3.weight, 0)
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = torch.flatten(x, 1)

        x = self.fc(x)

        return x

In [5]:
queue_size = 128*batch_size
momentum = 0.999
temperature = 0.07

learning_rate = 0.03
sgd_weight_decay = 0.0001
sgd_momentum = 0.9
num_epochs = 100

In [6]:
class MoCo_Model(nn.Module):
    def __init__(self):
        
        super(MoCo_Model, self).__init__()

        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature

        assert self.queue_size % batch_size == 0  # for simplicity

        # Load model
        self.query_encoder = ResNet(BasicBlock,[2,2,2,2],num_classes=dim_of_emb)
        self.key_encoder = ResNet(BasicBlock,[2,2,2,2],num_classes=dim_of_emb)

        self.query_encoder.fc = projection_MLP()
        self.key_encoder.fc = projection_MLP()


        # Initialize the key encoder to have the same values as query encoder
        # Do not update the key encoder via gradient
        for query_param, key_param in zip(self.query_encoder.parameters(), self.key_encoder.parameters()):
            key_param.data.copy_(query_param.data)
            key_param.requires_grad = False

        # Create the queue to store negative samples
        self.register_buffer("queue", torch.randn(self.queue_size, dim_of_emb))

        # Create pointer to store current position in the queue when enqueue and dequeue
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def momentum_update(self):
        
        # For each of the parameters in each encoder
        for p_q, p_k in zip(self.query_encoder.parameters(), self.key_encoder.parameters()):
            p_k.data = p_k.data * self.momentum + p_q.detach().data * (1. - self.momentum)

    @torch.no_grad()
    def shuffled_idx(self, batch_size):
        '''
        Generation of the shuffled indexes for the implementation of ShuffleBN.

        https://github.com/HobbitLong/CMC.

        args:
            batch_size (Tensor.int()):  Number of samples in a batch

        returns:
            shuffled_idxs (Tensor.long()): A random permutation index order for the shuffling of the current minibatch

            reverse_idxs (Tensor.long()): A reverse of the random permutation index order for the shuffling of the
                                            current minibatch to get back original sample order

        '''

        # Generate shuffled indexes
        shuffled_idxs = torch.randperm(batch_size).long().cuda()

        reverse_idxs = torch.zeros(batch_size).long().cuda()

        value = torch.arange(batch_size).long().cuda()

        reverse_idxs.index_copy_(0, shuffled_idxs, value)

        return shuffled_idxs, reverse_idxs

    @torch.no_grad()
    def update_queue(self, feat_k):

        curr_BS = feat_k.size(0)

        ptr = int(self.queue_ptr)

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[ptr:ptr + curr_BS, :] = feat_k

        # move pointer along to end of current batch
        ptr = (ptr + curr_BS) % self.queue_size

        # Store queue pointer as register_buffer
        self.queue_ptr[0] = ptr

    def InfoNCE_logits(self, f_q, f_k):

        f_k = f_k.detach()

        # Get queue from register_buffer
        f_mem = self.queue.clone().detach()

        # Normalize the feature representations
        f_q = nn.functional.normalize(f_q, dim=1)
        f_k = nn.functional.normalize(f_k, dim=1)
        f_mem = nn.functional.normalize(f_mem, dim=1)

        # Compute sim between positive views
        pos = torch.bmm(f_q.view(f_q.size(0), 1, -1),
                        f_k.view(f_k.size(0), -1, 1)).squeeze(-1)

        # Compute sim between postive and all negatives in the memory
        neg = torch.mm(f_q, f_mem.transpose(1, 0))

        logits = torch.cat((pos, neg), dim=1)

        logits /= self.temperature

        # Create labels, first logit is postive, all others are negative
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        return logits, labels

    def forward(self, x_q, x_k):

        batch_size = x_q.size(0)

        # Feature representations of the query view from the query encoder
        feat_q = self.query_encoder(x_q)

        # TODO: shuffle ids with distributed data parallel
        # Get shuffled and reversed indexes for the current minibatch
        shuffled_idxs, reverse_idxs = self.shuffled_idx(batch_size)

        with torch.no_grad():
            # Update the key encoder
            self.momentum_update()

            # Shuffle minibatch
            x_k = x_k[shuffled_idxs]

            # Feature representations of the shuffled key view from the key encoder
            feat_k = self.key_encoder(x_k)

            # reverse the shuffled samples to original position
            feat_k = feat_k[reverse_idxs]

        # Compute the logits for the InfoNCE contrastive loss.
        logit, label = self.InfoNCE_logits(feat_q, feat_k)

        # Update the queue/memory with the current key_encoder minibatch.
        self.update_queue(feat_k)

        return logit, label

In [7]:
mocoModel = MoCo_Model()

In [8]:
torch.cuda.set_device(device_to_use)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [9]:
mocoModel = mocoModel.to(device=device)

In [10]:
optimiser = torch.optim.SGD(mocoModel.parameters(), lr=learning_rate, weight_decay=sgd_weight_decay, momentum=sgd_momentum)

In [11]:
lr_decay = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, num_epochs)

In [12]:
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [13]:
augmentation = [
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

augTransform = transforms.Compose(augmentation)

In [14]:
#Pretrain
number_of_steps = len(dataloader)*num_epochs
progress_bar = tqdm(range(number_of_steps))

mocoModel.train()
for epoch in range(num_epochs):
    for i,(images,labels) in enumerate(dataloader):
        optimiser.zero_grad()
        #images : [bs,3,128,128]
        images = images.to(device)
        #Get two augmentations:
        query_images = augTransform(images)
        key_images = augTransform(images)

        all_logits, labels = mocoModel(query_images,key_images)
        #all_logits : [(queue_size + 1),emb_dim]
        #labels : [(queue_size + 1)], Tensor of all zeros

        loss = criterion(all_logits,labels)

        loss.backward()

        optimiser.step()

        progress_bar.update()

        if i%500 == 0:
            print("Pretrain_loss : ",loss.item())

    lr_decay.step()


  0%|          | 3/22800 [00:33<54:33:09,  8.61s/it] 

Pretrain_loss :  0.02857186272740364


  1%|          | 230/22800 [02:23<2:03:29,  3.05it/s]

Pretrain_loss :  5.432933807373047


  2%|▏         | 456/22800 [02:42<29:20, 12.69it/s]  

Pretrain_loss :  4.812164306640625


  3%|▎         | 688/22800 [03:28<2:04:57,  2.95it/s]

Pretrain_loss :  3.1110713481903076


  4%|▍         | 916/22800 [04:13<1:27:14,  4.18it/s]

Pretrain_loss :  2.9676005840301514


  5%|▌         | 1143/22800 [04:24<45:57,  7.85it/s] 

Pretrain_loss :  2.0683467388153076


  6%|▌         | 1373/22800 [04:38<1:11:56,  4.96it/s]

Pretrain_loss :  1.715352177619934


  7%|▋         | 1601/22800 [04:53<36:35,  9.66it/s]  

Pretrain_loss :  1.6454323530197144


  8%|▊         | 1825/22800 [05:38<2:04:50,  2.80it/s]

Pretrain_loss :  1.3108285665512085


  9%|▉         | 2053/22800 [06:21<2:16:53,  2.53it/s]

Pretrain_loss :  1.1967517137527466


 10%|█         | 2284/22800 [07:08<1:28:18,  3.87it/s]

Pretrain_loss :  1.2320719957351685


 11%|█         | 2511/22800 [07:51<1:03:24,  5.33it/s]

Pretrain_loss :  0.9905940890312195


 12%|█▏        | 2742/22800 [08:03<34:35,  9.66it/s]  

Pretrain_loss :  1.044512391090393


 13%|█▎        | 2968/22800 [08:16<44:31,  7.42it/s]

Pretrain_loss :  1.212280511856079


 14%|█▍        | 3195/22800 [09:00<1:58:22,  2.76it/s]

Pretrain_loss :  0.9807380437850952


 15%|█▌        | 3423/22800 [09:44<1:26:57,  3.71it/s]

Pretrain_loss :  0.6864845156669617


 16%|█▌        | 3649/22800 [10:04<1:06:00,  4.84it/s]

Pretrain_loss :  0.7475559711456299


 17%|█▋        | 3880/22800 [10:53<1:49:37,  2.88it/s]

Pretrain_loss :  0.7188102602958679


 18%|█▊        | 4109/22800 [11:36<1:05:31,  4.75it/s]

Pretrain_loss :  0.6822733283042908


 19%|█▉        | 4337/22800 [11:48<34:12,  9.00it/s]  

Pretrain_loss :  0.6863484382629395


 20%|██        | 4565/22800 [12:00<38:00,  8.00it/s]

Pretrain_loss :  0.5740864276885986


 21%|██        | 4791/22800 [12:12<39:08,  7.67it/s]

Pretrain_loss :  0.5262278914451599


 22%|██▏       | 5017/22800 [12:23<38:02,  7.79it/s]

Pretrain_loss :  0.5228994488716125


 23%|██▎       | 5247/22800 [13:06<2:24:42,  2.02it/s]

Pretrain_loss :  0.8031159043312073


 24%|██▍       | 5475/22800 [13:48<1:23:33,  3.46it/s]

Pretrain_loss :  0.5473595857620239


 25%|██▌       | 5703/22800 [14:28<1:16:52,  3.71it/s]

Pretrain_loss :  0.5799432992935181


 26%|██▌       | 5932/22800 [14:44<59:14,  4.75it/s]  

Pretrain_loss :  0.5977439284324646


 27%|██▋       | 6161/22800 [15:27<1:00:23,  4.59it/s]

Pretrain_loss :  0.4535023272037506


 28%|██▊       | 6387/22800 [15:58<1:11:02,  3.85it/s]

Pretrain_loss :  0.6267112493515015


 29%|██▉       | 6615/22800 [16:18<50:07,  5.38it/s]  

Pretrain_loss :  0.3857680857181549


 30%|███       | 6842/22800 [16:38<52:49,  5.04it/s]

Pretrain_loss :  0.43619000911712646


 31%|███       | 7070/22800 [16:57<51:28,  5.09it/s]

Pretrain_loss :  0.45538416504859924


 32%|███▏      | 7299/22800 [17:15<28:24,  9.10it/s]

Pretrain_loss :  0.4143941104412079


 33%|███▎      | 7526/22800 [17:29<1:11:23,  3.57it/s]

Pretrain_loss :  0.3596288561820984


 34%|███▍      | 7753/22800 [18:13<1:31:06,  2.75it/s]

Pretrain_loss :  0.5080121755599976


 35%|███▌      | 7985/22800 [18:48<31:51,  7.75it/s]  

Pretrain_loss :  0.3400283455848694


 36%|███▌      | 8214/22800 [19:00<25:58,  9.36it/s]

Pretrain_loss :  0.4514525830745697


 37%|███▋      | 8439/22800 [19:11<33:14,  7.20it/s]

Pretrain_loss :  0.42034757137298584


 38%|███▊      | 8667/22800 [19:23<34:01,  6.92it/s]

Pretrain_loss :  0.45257025957107544


 39%|███▉      | 8894/22800 [19:37<48:38,  4.76it/s]

Pretrain_loss :  0.37456637620925903


 40%|████      | 9123/22800 [20:25<2:03:45,  1.84it/s]

Pretrain_loss :  0.34982964396476746


 41%|████      | 9352/22800 [21:06<57:28,  3.90it/s]  

Pretrain_loss :  0.33774200081825256


 42%|████▏     | 9578/22800 [21:42<1:13:27,  3.00it/s]

Pretrain_loss :  0.2659962773323059


 43%|████▎     | 9807/22800 [21:58<59:56,  3.61it/s]  

Pretrain_loss :  0.30895504355430603


 44%|████▍     | 10033/22800 [22:39<58:38,  3.63it/s]

Pretrain_loss :  0.27158981561660767


 45%|████▌     | 10263/22800 [23:27<1:49:31,  1.91it/s]

Pretrain_loss :  0.38778960704803467


 46%|████▌     | 10493/22800 [23:45<24:10,  8.48it/s]  

Pretrain_loss :  0.35142460465431213


 47%|████▋     | 10721/22800 [23:57<22:24,  8.99it/s]

Pretrain_loss :  0.3041897416114807


 48%|████▊     | 10946/22800 [24:08<23:49,  8.29it/s]

Pretrain_loss :  0.2734185755252838


 49%|████▉     | 11179/22800 [24:22<29:40,  6.53it/s]

Pretrain_loss :  0.2829796075820923


 50%|█████     | 11402/22800 [25:06<57:53,  3.28it/s]

Pretrain_loss :  0.30676722526550293


 51%|█████     | 11632/22800 [25:18<19:57,  9.33it/s]

Pretrain_loss :  0.34706705808639526


 52%|█████▏    | 11862/22800 [25:30<26:44,  6.82it/s]

Pretrain_loss :  0.264839768409729


 53%|█████▎    | 12089/22800 [25:43<37:26,  4.77it/s]

Pretrain_loss :  0.30805954337120056


 54%|█████▍    | 12315/22800 [26:28<44:19,  3.94it/s]

Pretrain_loss :  0.3121994435787201


 55%|█████▌    | 12542/22800 [26:47<29:43,  5.75it/s]

Pretrain_loss :  0.2698330283164978


 56%|█████▌    | 12769/22800 [27:30<51:44,  3.23it/s]

Pretrain_loss :  0.3213063180446625


 57%|█████▋    | 13002/22800 [27:42<17:26,  9.37it/s]

Pretrain_loss :  0.21232540905475616


 58%|█████▊    | 13228/22800 [27:54<17:25,  9.16it/s]

Pretrain_loss :  0.35268789529800415


 59%|█████▉    | 13455/22800 [28:06<17:44,  8.78it/s]

Pretrain_loss :  0.45303910970687866


 60%|██████    | 13683/22800 [28:17<16:23,  9.27it/s]

Pretrain_loss :  0.26786649227142334


 61%|██████    | 13911/22800 [28:30<19:10,  7.73it/s]

Pretrain_loss :  0.28472304344177246


 62%|██████▏   | 14140/22800 [28:44<31:57,  4.52it/s]

Pretrain_loss :  0.31563007831573486


 63%|██████▎   | 14366/22800 [29:30<45:08,  3.11it/s]

Pretrain_loss :  0.2795591950416565


 64%|██████▍   | 14594/22800 [29:52<54:05,  2.53it/s]  

Pretrain_loss :  0.24306200444698334


 65%|██████▌   | 14822/22800 [30:38<37:53,  3.51it/s]

Pretrain_loss :  0.19499675929546356


 66%|██████▌   | 15048/22800 [30:49<04:53, 26.38it/s]

Pretrain_loss :  0.3647770881652832


 67%|██████▋   | 15280/22800 [31:33<30:34,  4.10it/s]

Pretrain_loss :  0.2503824532032013


 68%|██████▊   | 15507/22800 [32:16<33:44,  3.60it/s]

Pretrain_loss :  0.2896762192249298


 69%|██████▉   | 15735/22800 [32:29<16:43,  7.04it/s]

Pretrain_loss :  0.3267640173435211


 70%|███████   | 15962/22800 [32:42<14:59,  7.60it/s]

Pretrain_loss :  0.3271937668323517


 71%|███████   | 16191/22800 [33:02<38:30,  2.86it/s]

Pretrain_loss :  0.2367960512638092


 72%|███████▏  | 16419/22800 [33:47<24:52,  4.27it/s]

Pretrain_loss :  0.20355947315692902


 73%|███████▎  | 16646/22800 [34:27<33:26,  3.07it/s]

Pretrain_loss :  0.2637384235858917


 74%|███████▍  | 16875/22800 [35:10<29:46,  3.32it/s]

Pretrain_loss :  0.24255460500717163


 75%|███████▌  | 17102/22800 [35:55<25:26,  3.73it/s]

Pretrain_loss :  0.2367638796567917


 76%|███████▌  | 17330/22800 [36:31<25:53,  3.52it/s]

Pretrain_loss :  0.27130815386772156


 77%|███████▋  | 17558/22800 [37:16<42:02,  2.08it/s]

Pretrain_loss :  0.2813325822353363


 78%|███████▊  | 17787/22800 [38:01<28:34,  2.92it/s]

Pretrain_loss :  0.24158549308776855


 79%|███████▉  | 18012/22800 [38:13<03:43, 21.41it/s]

Pretrain_loss :  0.26637107133865356


 80%|████████  | 18243/22800 [38:35<21:01,  3.61it/s]

Pretrain_loss :  0.27872592210769653


 81%|████████  | 18471/22800 [39:17<16:40,  4.32it/s]

Pretrain_loss :  0.33952081203460693


 82%|████████▏ | 18697/22800 [39:30<08:38,  7.92it/s]

Pretrain_loss :  0.33644068241119385


 83%|████████▎ | 18927/22800 [40:12<20:09,  3.20it/s]

Pretrain_loss :  0.2937498986721039


 84%|████████▍ | 19156/22800 [40:57<13:52,  4.38it/s]

Pretrain_loss :  0.22684931755065918


 85%|████████▌ | 19383/22800 [41:11<16:16,  3.50it/s]

Pretrain_loss :  0.24903248250484467


 86%|████████▌ | 19610/22800 [41:54<13:10,  4.03it/s]

Pretrain_loss :  0.31990760564804077


 87%|████████▋ | 19836/22800 [42:38<07:10,  6.88it/s]

Pretrain_loss :  0.2068905383348465


 88%|████████▊ | 20067/22800 [43:25<12:11,  3.74it/s]

Pretrain_loss :  0.23989754915237427


 89%|████████▉ | 20295/22800 [43:45<06:04,  6.87it/s]

Pretrain_loss :  0.3143938183784485


 90%|█████████ | 20523/22800 [43:57<05:12,  7.28it/s]

Pretrain_loss :  0.2616685628890991


 91%|█████████ | 20752/22800 [44:09<03:06, 10.98it/s]

Pretrain_loss :  0.223558709025383


 92%|█████████▏| 20982/22800 [44:20<03:14,  9.33it/s]

Pretrain_loss :  0.18972350656986237


 93%|█████████▎| 21205/22800 [44:32<03:17,  8.08it/s]

Pretrain_loss :  0.17588089406490326


 94%|█████████▍| 21437/22800 [45:16<05:58,  3.81it/s]

Pretrain_loss :  0.2660849094390869


 95%|█████████▌| 21664/22800 [46:00<03:53,  4.86it/s]

Pretrain_loss :  0.23346643149852753


 96%|█████████▌| 21890/22800 [46:19<04:29,  3.37it/s]

Pretrain_loss :  0.33552953600883484


 97%|█████████▋| 22119/22800 [46:42<03:31,  3.22it/s]

Pretrain_loss :  0.34154796600341797


 98%|█████████▊| 22347/22800 [47:27<02:12,  3.42it/s]

Pretrain_loss :  0.27680158615112305


 99%|█████████▉| 22577/22800 [48:14<01:11,  3.13it/s]

Pretrain_loss :  0.211419016122818


100%|█████████▉| 22799/22800 [48:56<00:00,  5.73it/s]

100%|██████████| 22800/22800 [49:10<00:00,  5.73it/s]

In [17]:
ModelPath = "/data4/home/hrishikeshj/DLNLP/ADRLA3/trainedModels/TR1/mocoModel.pt"

In [18]:
torch.save(mocoModel.state_dict(),ModelPath)

In [15]:
augTransform = transforms.Compose(augmentation)


In [17]:
for onesample in dataloader:
    break

In [18]:
imageBatch = onesample[0].to(device)

In [19]:
x_q = augTransform(imageBatch)



In [20]:
x_k = augTransform(imageBatch)

In [21]:
op = mocoModel(x_q,x_k)

In [23]:
op[0].shape

torch.Size([64, 8193])

In [25]:
loss = criterion(op[0],op[1])

In [29]:
op[0]

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<DivBackward0>)