In [1]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm.auto import tqdm
import torch.nn.functional as F
import Models.Network as models
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Loading Dataset

dataroot = r"/data4/home/hrishikeshj/DLNLP/ADRLA2/Dataset/train"
image_size = 128
batch_size = 64

dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

testDataPath = r"/data4/home/hrishikeshj/DLNLP/ADRLA2/Dataset/val"

test_dataset = dset.ImageFolder(root=testDataPath,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32,
                                         shuffle=True, num_workers=8, drop_last=True)

In [3]:
dim_of_emb = 128
device_to_use = 1

In [4]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class projection_MLP(nn.Module):
    def __init__(self):
        super(projection_MLP, self).__init__()

        n_channels = 512

        self.projection_head = nn.Sequential()

        self.projection_head.add_module('W1', nn.Linear(
            n_channels, n_channels))
        self.projection_head.add_module('ReLU', nn.ReLU())
        self.projection_head.add_module('W2', nn.Linear(
            n_channels, 128))

    def forward(self, x):
        return self.projection_head(x)

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=128, zero_init_residual=False,
                 groups=1, width_per_group=64):
        super(ResNet, self).__init__()
        
        self._norm_layer = nn.BatchNorm2d

        self.register_buffer

        self.inplanes = 64
        self.dilation = 1
        
        self.groups = groups
        self.base_width = width_per_group

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                   bias=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

            

        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(512 * block.expansion,num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # if isinstance(m, Bottleneck):
                #     nn.init.constant_(m.bn3.weight, 0)
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = torch.flatten(x, 1)

        x = self.fc(x)

        return x

In [5]:
queue_size = 128*batch_size
momentum = 0.999
temperature = 0.07

learning_rate = 0.03
sgd_weight_decay = 0.0001
sgd_momentum = 0.9
num_epochs = 100
num_classes = 3

In [6]:
class MoCo_Model(nn.Module):
    def __init__(self):
        
        super(MoCo_Model, self).__init__()

        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature

        assert self.queue_size % batch_size == 0  # for simplicity

        # Load model
        self.query_encoder = ResNet(BasicBlock,[2,2,2,2],num_classes=dim_of_emb)
        self.key_encoder = ResNet(BasicBlock,[2,2,2,2],num_classes=dim_of_emb)

        self.query_encoder.fc = projection_MLP()
        self.key_encoder.fc = projection_MLP()


        # Initialize the key encoder to have the same values as query encoder
        # Do not update the key encoder via gradient
        for query_param, key_param in zip(self.query_encoder.parameters(), self.key_encoder.parameters()):
            key_param.data.copy_(query_param.data)
            key_param.requires_grad = False

        # Create the queue to store negative samples
        self.register_buffer("queue", torch.randn(self.queue_size, dim_of_emb))

        # Create pointer to store current position in the queue when enqueue and dequeue
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def momentum_update(self):
        
        # For each of the parameters in each encoder
        for p_q, p_k in zip(self.query_encoder.parameters(), self.key_encoder.parameters()):
            p_k.data = p_k.data * self.momentum + p_q.detach().data * (1. - self.momentum)

    @torch.no_grad()
    def shuffled_idx(self, batch_size):
        '''
        Generation of the shuffled indexes for the implementation of ShuffleBN.

        https://github.com/HobbitLong/CMC.

        args:
            batch_size (Tensor.int()):  Number of samples in a batch

        returns:
            shuffled_idxs (Tensor.long()): A random permutation index order for the shuffling of the current minibatch

            reverse_idxs (Tensor.long()): A reverse of the random permutation index order for the shuffling of the
                                            current minibatch to get back original sample order

        '''

        # Generate shuffled indexes
        shuffled_idxs = torch.randperm(batch_size).long().cuda()

        reverse_idxs = torch.zeros(batch_size).long().cuda()

        value = torch.arange(batch_size).long().cuda()

        reverse_idxs.index_copy_(0, shuffled_idxs, value)

        return shuffled_idxs, reverse_idxs

    @torch.no_grad()
    def update_queue(self, feat_k):

        curr_BS = feat_k.size(0)

        ptr = int(self.queue_ptr)

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[ptr:ptr + curr_BS, :] = feat_k

        # move pointer along to end of current batch
        ptr = (ptr + curr_BS) % self.queue_size

        # Store queue pointer as register_buffer
        self.queue_ptr[0] = ptr

    def InfoNCE_logits(self, f_q, f_k):

        f_k = f_k.detach()

        # Get queue from register_buffer
        f_mem = self.queue.clone().detach()

        # Normalize the feature representations
        f_q = nn.functional.normalize(f_q, dim=1)
        f_k = nn.functional.normalize(f_k, dim=1)
        f_mem = nn.functional.normalize(f_mem, dim=1)

        # Compute sim between positive views
        pos = torch.bmm(f_q.view(f_q.size(0), 1, -1),
                        f_k.view(f_k.size(0), -1, 1)).squeeze(-1)

        # Compute sim between postive and all negatives in the memory
        neg = torch.mm(f_q, f_mem.transpose(1, 0))

        logits = torch.cat((pos, neg), dim=1)

        logits /= self.temperature

        # Create labels, first logit is postive, all others are negative
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        return logits, labels

    def forward(self, x_q, x_k):

        batch_size = x_q.size(0)

        # Feature representations of the query view from the query encoder
        feat_q = self.query_encoder(x_q)

        # TODO: shuffle ids with distributed data parallel
        # Get shuffled and reversed indexes for the current minibatch
        shuffled_idxs, reverse_idxs = self.shuffled_idx(batch_size)

        with torch.no_grad():
            # Update the key encoder
            self.momentum_update()

            # Shuffle minibatch
            x_k = x_k[shuffled_idxs]

            # Feature representations of the shuffled key view from the key encoder
            feat_k = self.key_encoder(x_k)

            # reverse the shuffled samples to original position
            feat_k = feat_k[reverse_idxs]

        # Compute the logits for the InfoNCE contrastive loss.
        logit, label = self.InfoNCE_logits(feat_q, feat_k)

        # Update the queue/memory with the current key_encoder minibatch.
        self.update_queue(feat_k)

        return logit, label

In [7]:
query_encoder = ResNet(BasicBlock,[2,2,2,2],num_classes=3)

In [8]:
modelPath = "/data4/home/hrishikeshj/DLNLP/ADRLA3/trainedModels/TR1/mocoModel.pt"
moco_state_dict = torch.load(modelPath)

In [9]:
for k in list(moco_state_dict.keys()):
    # retain only encoder_q up to before the embedding layer
    if k.startswith('query_encoder.') and not k.startswith('query_encoder.fc'):
        # remove prefix
        moco_state_dict[k[len("query_encoder."):]] = moco_state_dict[k]
    # delete renamed or unused k
    del moco_state_dict[k]

In [10]:
query_encoder.load_state_dict(moco_state_dict,strict=False)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])

In [11]:
torch.cuda.set_device(device_to_use)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [12]:
query_encoder = query_encoder.to(device)

In [13]:
class_wise_indices = {0:[i for i in range(5153)],1:[i for i in range(5153,9892)],2:[i for i in range(9892,14629)]}


percentage = 0.6

subset_indices = []

for label in range(3):
    num_of_samples = len(class_wise_indices[label])
    num_of_samples = round(percentage*num_of_samples)
    subset_indices.extend(class_wise_indices[label][:num_of_samples])

print("After sampling : ",len(subset_indices))

After sampling :  8777


In [14]:
import random

In [15]:
random.shuffle(subset_indices)

In [16]:
subset_sampler = torch.utils.data.SubsetRandomSampler(subset_indices)

In [17]:
batch_size = 32
learning_rate = 0.001
fine_tune_epochs = 200

In [18]:
subset_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size
                                         , num_workers=8, drop_last=True,sampler=subset_sampler)

In [19]:
optimiser = torch.optim.Adam(query_encoder.parameters(), lr=learning_rate)

In [20]:
lr_decay = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, num_epochs)

In [21]:
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [22]:
patience = 40

In [23]:
def evaluate(encoder, test_dataloader):
    valid_loss = 0.0  # reset loss
    valid_acc = 0.0  # reset acc
    
    criterion = nn.CrossEntropyLoss().cuda()

    # Evaluate both encoder and class head
    encoder.eval()

    for i, (inputs, target) in enumerate(test_dataloader):

        # Do not compute gradient for encoder and classification head
        encoder.zero_grad()
        inputs = inputs.to(device)
        target = target.to(device)
        # Forward pass

        output = encoder(inputs)

        loss = criterion(output, target)

        valid_loss += loss.item()

        predicted = output.argmax(-1)

        acc = (predicted == target).sum().item() / target.size(0)

        valid_acc += acc

    valid_loss = valid_loss / len(test_dataloader)

    valid_acc = valid_acc / len(test_dataloader)

    print("Test loss : ",valid_loss)
    print("Test Acc : ",valid_acc)
    return valid_loss, valid_acc

In [24]:
model_storage_path = "/data4/home/hrishikeshj/DLNLP/ADRLA3/FineTunedModels/FT10Perc.pt"
progress_bar = tqdm(range(fine_tune_epochs*len(subset_dataloader)))
best_test_acc = 0.0
best_epoch = 0
patience_counter = 0
for epoch in range(fine_tune_epochs):

    # Freeze the encoder, train classification head
    query_encoder.eval()
    run_loss = 0.0
    run_acc = 0.0

    

    for i, (inputs, target) in enumerate(subset_dataloader):

        inputs = inputs.to(device)
        target = target.to(device)

        # Forward pass
        optimiser.zero_grad()

        # Do not compute the gradients for the frozen encoder
        output = query_encoder(inputs)

        # Take pretrained encoder representations
        loss = criterion(output, target)

        loss.backward()

        optimiser.step()

        run_loss += loss.item()

        predicted = output.argmax(1)

        acc = (predicted == target).sum().item() / target.size(0)

        run_acc += acc

        progress_bar.update(1)

    epoch_finetune_loss = run_loss / len(subset_dataloader)  # sample_count

    epoch_finetune_acc = run_acc / len(subset_dataloader)

    ''' Update Schedulers '''
    # Decay lr with CosineAnnealingLR
    lr_decay.step()

    ''' Printing '''
    print("Epoch loss : ",epoch_finetune_loss)
    print("Epoch accuracy : ",epoch_finetune_acc)

    test_valid_loss, test_valid_acc = evaluate(query_encoder, test_dataloader)

    if test_valid_acc>best_test_acc:
        patience_counter = 0
        best_test_acc = test_valid_acc
        best_epoch = epoch + 1
        torch.save(query_encoder.state_dict(),model_storage_path)
    else:
        patience_counter += 1
        if patience_counter>=patience:
            print("Stopping early as no improvement since 20 epochs")
            break
        else:
            print("Patience Counter : ",patience_counter)

    
    epoch_finetune_loss = None  # reset loss
    epoch_finetune_acc = None

    torch.cuda.empty_cache()

print("Best Epoch : ",best_epoch)
print("Best test Accuracy : ",best_test_acc)


  0%|          | 0/54800 [00:00<?, ?it/s]

  0%|          | 272/54800 [00:35<36:40, 24.77it/s]  

Epoch loss :  1.1295127868652344
Epoch accuracy :  0.3418111313868613
Test loss :  1.0978111210076704
Test Acc :  0.33695652173913043


  1%|          | 547/54800 [01:17<1:15:44, 11.94it/s] 

Epoch loss :  0.900285028410654
Epoch accuracy :  0.5258895985401459
Test loss :  0.5573484333960906
Test Acc :  0.782608695652174


  1%|▏         | 821/54800 [02:03<1:38:37,  9.12it/s] 

Epoch loss :  0.41279507983122427
Epoch accuracy :  0.839530109489051
Test loss :  0.34538228097169293
Test Acc :  0.8729619565217391


  2%|▏         | 1096/54800 [02:49<1:29:31, 10.00it/s]

Epoch loss :  0.26855550469817036
Epoch accuracy :  0.895529197080292
Test loss :  0.2792699209049992
Test Acc :  0.8899456521739131


  2%|▏         | 1369/54800 [03:34<1:16:10, 11.69it/s] 

Epoch loss :  0.2037088214453772
Epoch accuracy :  0.9265510948905109
Test loss :  0.2320314585838629
Test Acc :  0.9150815217391305


  3%|▎         | 1643/54800 [04:19<1:22:04, 10.79it/s] 

Epoch loss :  0.16539036072654664
Epoch accuracy :  0.9382983576642335
Test loss :  0.20228466448252616
Test Acc :  0.9245923913043478


  3%|▎         | 1917/54800 [05:06<1:22:38, 10.67it/s] 

Epoch loss :  0.11617894037660674
Epoch accuracy :  0.9563184306569343
Test loss :  0.18717313266318777
Test Acc :  0.9368206521739131


  4%|▍         | 2192/54800 [05:56<1:15:53, 11.55it/s] 

Epoch loss :  0.08105852424989651
Epoch accuracy :  0.9717153284671532
Test loss :  0.21265550731154886
Test Acc :  0.9266304347826086
Patience Counter :  1


  4%|▍         | 2465/54800 [06:44<1:18:15, 11.15it/s] 

Epoch loss :  0.06626833561722047
Epoch accuracy :  0.9767335766423357
Test loss :  0.172759461277367
Test Acc :  0.951766304347826


  5%|▌         | 2740/54800 [07:28<1:10:31, 12.30it/s] 

Epoch loss :  0.05059865677544854
Epoch accuracy :  0.9819799270072993
Test loss :  0.2609726901611556
Test Acc :  0.9375
Patience Counter :  1


  5%|▌         | 3013/54800 [08:20<1:11:32, 12.06it/s] 

Epoch loss :  0.05047193981266715
Epoch accuracy :  0.9819799270072993
Test loss :  0.30226685693892447
Test Acc :  0.9225543478260869
Patience Counter :  2


  6%|▌         | 3287/54800 [09:07<1:19:26, 10.81it/s] 

Epoch loss :  0.040376639867479316
Epoch accuracy :  0.985059306569343
Test loss :  0.31979483544178633
Test Acc :  0.904891304347826
Patience Counter :  3


  6%|▋         | 3562/54800 [09:52<1:07:07, 12.72it/s] 

Epoch loss :  0.043541045481451476
Epoch accuracy :  0.9858576642335767
Test loss :  0.2348329950147435
Test Acc :  0.9341032608695652
Patience Counter :  4


  7%|▋         | 3836/54800 [10:37<1:45:44,  8.03it/s] 

Epoch loss :  0.03313281161703955
Epoch accuracy :  0.9891651459854015
Test loss :  0.20524870326904499
Test Acc :  0.9375
Patience Counter :  5


  7%|▋         | 4109/54800 [11:21<1:17:34, 10.89it/s] 

Epoch loss :  0.022993587187586716
Epoch accuracy :  0.9920164233576643
Test loss :  0.28868804115842545
Test Acc :  0.9300271739130435
Patience Counter :  6


  8%|▊         | 4384/54800 [11:59<1:08:16, 12.31it/s] 

Epoch loss :  0.03218088729055743
Epoch accuracy :  0.9911040145985401
Test loss :  0.2663768073948829
Test Acc :  0.9252717391304348
Patience Counter :  7


  8%|▊         | 4657/54800 [12:50<1:28:03,  9.49it/s] 

Epoch loss :  0.014953106720761845
Epoch accuracy :  0.994867700729927
Test loss :  0.3268871917085641
Test Acc :  0.936141304347826
Patience Counter :  8


  9%|▉         | 4932/54800 [13:33<1:08:33, 12.12it/s] 

Epoch loss :  0.023414645764298428
Epoch accuracy :  0.9931569343065694
Test loss :  0.22467836806469638
Test Acc :  0.9341032608695652
Patience Counter :  9


  9%|▉         | 5205/54800 [14:16<1:09:00, 11.98it/s] 

Epoch loss :  0.018539330935497148
Epoch accuracy :  0.9942974452554745
Test loss :  0.4206996425409275
Test Acc :  0.9341032608695652
Patience Counter :  10


 10%|▉         | 5479/54800 [15:01<1:24:47,  9.69it/s] 

Epoch loss :  0.03223365119284097
Epoch accuracy :  0.9903056569343066
Test loss :  0.24444171780993676
Test Acc :  0.9456521739130435
Patience Counter :  11


 10%|█         | 5754/54800 [15:44<2:06:29,  6.46it/s] 

Epoch loss :  0.01695851800636475
Epoch accuracy :  0.9937271897810219
Test loss :  0.31664699890002934
Test Acc :  0.9375
Patience Counter :  12


 11%|█         | 6028/54800 [16:27<1:34:24,  8.61it/s] 

Epoch loss :  0.017603519805366924
Epoch accuracy :  0.9937271897810219
Test loss :  0.30911913383802725
Test Acc :  0.9408967391304348
Patience Counter :  13


 12%|█▏        | 6302/54800 [17:14<1:07:17, 12.01it/s] 

Epoch loss :  0.013744410059614642
Epoch accuracy :  0.9960082116788321
Test loss :  0.23710170496728414
Test Acc :  0.9524456521739131


 12%|█▏        | 6575/54800 [17:58<1:11:05, 11.31it/s] 

Epoch loss :  0.008464950688816184
Epoch accuracy :  0.9977189781021898
Test loss :  0.2525075035386101
Test Acc :  0.9510869565217391
Patience Counter :  1


 12%|█▏        | 6849/54800 [18:41<1:09:26, 11.51it/s] 

Epoch loss :  0.022204318728177908
Epoch accuracy :  0.9932709854014599
Test loss :  0.19917236643074
Test Acc :  0.9381793478260869
Patience Counter :  2


 13%|█▎        | 7124/54800 [19:25<1:40:54,  7.87it/s] 

Epoch loss :  0.014621064320177079
Epoch accuracy :  0.9955520072992701
Test loss :  0.4033732295672804
Test Acc :  0.936141304347826
Patience Counter :  3


 14%|█▎        | 7398/54800 [20:07<1:04:35, 12.23it/s] 

Epoch loss :  0.013542426440177434
Epoch accuracy :  0.9962363138686131
Test loss :  0.2853455963328167
Test Acc :  0.9510869565217391
Patience Counter :  4


 14%|█▍        | 7671/54800 [20:55<1:08:51, 11.41it/s] 

Epoch loss :  0.025315876318043957
Epoch accuracy :  0.9928147810218978
Test loss :  0.24753458729094785
Test Acc :  0.9239130434782609
Patience Counter :  5


 14%|█▍        | 7946/54800 [21:35<1:24:27,  9.25it/s] 

Epoch loss :  0.018108796275854246
Epoch accuracy :  0.9955520072992701
Test loss :  0.3163058325058639
Test Acc :  0.9510869565217391
Patience Counter :  6


 15%|█▌        | 8220/54800 [22:17<1:07:54, 11.43it/s] 

Epoch loss :  0.012489731312714877
Epoch accuracy :  0.9945255474452555
Test loss :  0.34660201033820276
Test Acc :  0.9402173913043478
Patience Counter :  7


 15%|█▌        | 8493/54800 [23:04<1:18:47,  9.79it/s] 

Epoch loss :  0.017807747160816234
Epoch accuracy :  0.9933850364963503
Test loss :  0.23250077045193632
Test Acc :  0.9415760869565217
Patience Counter :  8


 16%|█▌        | 8768/54800 [23:48<1:56:30,  6.58it/s] 

Epoch loss :  0.0050064904108556435
Epoch accuracy :  0.9982892335766423
Test loss :  0.3233624858349567
Test Acc :  0.9476902173913043
Patience Counter :  9


 16%|█▋        | 9042/54800 [24:33<1:01:51, 12.33it/s] 

Epoch loss :  0.005375374532588577
Epoch accuracy :  0.9988594890510949


 16%|█▋        | 9042/54800 [24:43<1:01:51, 12.33it/s]

Test loss :  0.31593507793782605
Test Acc :  0.9449728260869565
Patience Counter :  10


 17%|█▋        | 9315/54800 [25:16<1:08:38, 11.04it/s] 

Epoch loss :  0.00861041923550162
Epoch accuracy :  0.9965784671532847
Test loss :  0.42226808684015693
Test Acc :  0.9415760869565217
Patience Counter :  11


 17%|█▋        | 9589/54800 [25:55<1:04:25, 11.70it/s] 

Epoch loss :  0.01416221718492673
Epoch accuracy :  0.9963503649635036
Test loss :  0.3575813752164401
Test Acc :  0.9442934782608695
Patience Counter :  12


 18%|█▊        | 9863/54800 [26:39<1:19:56,  9.37it/s] 

Epoch loss :  0.0010823844515977863
Epoch accuracy :  0.9994297445255474
Test loss :  0.4525549452316583
Test Acc :  0.9341032608695652
Patience Counter :  13


 18%|█▊        | 10137/54800 [27:24<1:53:36,  6.55it/s]

Epoch loss :  0.01545940178527939
Epoch accuracy :  0.9946395985401459
Test loss :  0.35898794795629446
Test Acc :  0.9483695652173914
Patience Counter :  14


 19%|█▉        | 10412/54800 [28:10<1:11:06, 10.40it/s] 

Epoch loss :  0.006948394332888901
Epoch accuracy :  0.9979470802919708
Test loss :  0.5845100642384394
Test Acc :  0.8953804347826086
Patience Counter :  15


 19%|█▉        | 10685/54800 [28:54<1:00:20, 12.18it/s] 

Epoch loss :  0.011185460413760485
Epoch accuracy :  0.9970346715328468
Test loss :  0.29006670716253163
Test Acc :  0.9497282608695652
Patience Counter :  16


 20%|██        | 10960/54800 [29:40<1:03:32, 11.50it/s] 

Epoch loss :  0.0008522529194346921
Epoch accuracy :  0.9998859489051095
Test loss :  0.38571234916605096
Test Acc :  0.9551630434782609


 20%|██        | 11233/54800 [30:22<1:01:30, 11.80it/s] 

Epoch loss :  4.225515767804982e-05
Epoch accuracy :  1.0
Test loss :  0.4227496343482926
Test Acc :  0.9538043478260869
Patience Counter :  1


 21%|██        | 11507/54800 [31:08<2:04:44,  5.78it/s] 

Epoch loss :  2.070278633409129e-06
Epoch accuracy :  1.0
Test loss :  0.4489888282065299
Test Acc :  0.9538043478260869
Patience Counter :  2


 22%|██▏       | 11782/54800 [31:56<1:23:59,  8.54it/s] 

Epoch loss :  1.1798150607702718e-06
Epoch accuracy :  1.0
Test loss :  0.48562962520786046
Test Acc :  0.953125
Patience Counter :  3


 22%|██▏       | 12055/54800 [32:43<59:23, 12.00it/s]   

Epoch loss :  7.337121537606623e-07
Epoch accuracy :  1.0
Test loss :  0.5028589894089813
Test Acc :  0.953125
Patience Counter :  4


 22%|██▎       | 12330/54800 [33:24<1:21:47,  8.65it/s] 

Epoch loss :  4.560835722779064e-07
Epoch accuracy :  1.0


 22%|██▎       | 12330/54800 [33:34<1:21:47,  8.65it/s]

Test loss :  0.5228779285759125
Test Acc :  0.9524456521739131
Patience Counter :  5


 23%|██▎       | 12604/54800 [34:12<1:20:43,  8.71it/s] 

Epoch loss :  2.9064586822600625e-07
Epoch accuracy :  1.0
Test loss :  0.5417103605719065
Test Acc :  0.9524456521739131
Patience Counter :  6


 23%|██▎       | 12877/54800 [34:57<57:12, 12.21it/s]   

Epoch loss :  1.9434993880258553e-07
Epoch accuracy :  1.0
Test loss :  0.5415610074167782
Test Acc :  0.9538043478260869
Patience Counter :  7


 24%|██▍       | 13152/54800 [35:42<1:07:18, 10.31it/s] 

Epoch loss :  1.3374130346821635e-07
Epoch accuracy :  1.0
Test loss :  0.569241383085393
Test Acc :  0.953125
Patience Counter :  8


 24%|██▍       | 13425/54800 [36:26<1:11:17,  9.67it/s] 

Epoch loss :  1.033686813752505e-07
Epoch accuracy :  1.0
Test loss :  0.5747915632874802
Test Acc :  0.9538043478260869
Patience Counter :  9


 25%|██▌       | 13700/54800 [37:08<1:00:06, 11.40it/s] 

Epoch loss :  7.81759184645788e-08
Epoch accuracy :  1.0
Test loss :  0.5701229617086974
Test Acc :  0.9565217391304348


 25%|██▌       | 13973/54800 [37:53<1:27:12,  7.80it/s] 

Epoch loss :  6.039269319104587e-08
Epoch accuracy :  1.0
Test loss :  0.5761270121515755
Test Acc :  0.9551630434782609
Patience Counter :  1


 26%|██▌       | 14247/54800 [38:37<1:11:14,  9.49it/s] 

Epoch loss :  4.739516835280327e-08
Epoch accuracy :  1.0
Test loss :  0.6010882644897365
Test Acc :  0.9558423913043478
Patience Counter :  2


 26%|██▋       | 14522/54800 [39:23<1:15:11,  8.93it/s] 

Epoch loss :  3.7891716554382285e-08
Epoch accuracy :  1.0
Test loss :  0.6064691877952775
Test Acc :  0.9572010869565217


 27%|██▋       | 14795/54800 [40:10<55:27, 12.02it/s]   

Epoch loss :  3.014208957126427e-08
Epoch accuracy :  1.0
Test loss :  0.6064735019632385
Test Acc :  0.9565217391304348
Patience Counter :  1


 28%|██▊       | 15070/54800 [40:56<1:02:24, 10.61it/s] 

Epoch loss :  2.4703745764349107e-08
Epoch accuracy :  1.0
Test loss :  0.6002865047279616
Test Acc :  0.9572010869565217
Patience Counter :  2


 28%|██▊       | 15344/54800 [41:36<56:25, 11.65it/s]   

Epoch loss :  2.0421049305728992e-08
Epoch accuracy :  1.0
Test loss :  0.636841701659726
Test Acc :  0.9558423913043478
Patience Counter :  3


 28%|██▊       | 15617/54800 [42:20<1:01:06, 10.69it/s] 

Epoch loss :  1.6994887894377346e-08
Epoch accuracy :  1.0
Test loss :  0.6417185251830282
Test Acc :  0.9565217391304348
Patience Counter :  4


 29%|██▉       | 15891/54800 [43:04<53:58, 12.01it/s]   

Epoch loss :  1.41397574022502e-08
Epoch accuracy :  1.0
Test loss :  0.6438470531836316
Test Acc :  0.9565217391304348
Patience Counter :  5


 29%|██▉       | 16165/54800 [43:46<57:04, 11.28it/s]   

Epoch loss :  1.2032392060004109e-08
Epoch accuracy :  1.0
Test loss :  0.6581960999755033
Test Acc :  0.9565217391304348
Patience Counter :  6


 30%|██▉       | 16439/54800 [44:25<52:14, 12.24it/s]   

Epoch loss :  1.02241378186092e-08
Epoch accuracy :  1.0


 30%|███       | 16440/54800 [44:35<52:14, 12.24it/s]

Test loss :  0.6413635897902363
Test Acc :  0.9565217391304348
Patience Counter :  7


 30%|███       | 16713/54800 [45:08<57:07, 11.11it/s]   

Epoch loss :  8.687799397974341e-09
Epoch accuracy :  1.0
Test loss :  0.6295438818785379
Test Acc :  0.9565217391304348
Patience Counter :  8


 31%|███       | 16987/54800 [45:49<54:52, 11.49it/s]   

Epoch loss :  7.640913510182096e-09
Epoch accuracy :  1.0
Test loss :  0.6653522965882802
Test Acc :  0.9565217391304348
Patience Counter :  9


 32%|███▏      | 17262/54800 [46:38<55:24, 11.29it/s]   

Epoch loss :  6.63481633705036e-09
Epoch accuracy :  1.0
Test loss :  0.6582775160100682
Test Acc :  0.9572010869565217
Patience Counter :  10


 32%|███▏      | 17536/54800 [47:22<51:55, 11.96it/s]   

Epoch loss :  5.914232175367547e-09
Epoch accuracy :  1.0
Test loss :  0.6724817767598615
Test Acc :  0.9578804347826086


 32%|███▏      | 17809/54800 [48:10<1:45:09,  5.86it/s] 

Epoch loss :  5.22084051011449e-09
Epoch accuracy :  1.0
Test loss :  0.6897362028324862
Test Acc :  0.9558423913043478
Patience Counter :  1


 33%|███▎      | 18084/54800 [48:54<1:36:46,  6.32it/s] 

Epoch loss :  4.677003156213299e-09
Epoch accuracy :  1.0
Test loss :  0.6966661585020483
Test Acc :  0.9551630434782609
Patience Counter :  2


 34%|███▎      | 18358/54800 [49:38<51:02, 11.90it/s]   

Epoch loss :  4.037994297409739e-09
Epoch accuracy :  1.0
Test loss :  0.6889785846180051
Test Acc :  0.9565217391304348
Patience Counter :  3


 34%|███▍      | 18631/54800 [50:20<52:20, 11.52it/s]   

Epoch loss :  3.643712391015654e-09
Epoch accuracy :  1.0
Test loss :  0.6788266553095711
Test Acc :  0.9565217391304348
Patience Counter :  4


 34%|███▍      | 18906/54800 [51:05<1:24:30,  7.08it/s] 

Epoch loss :  3.2630261809664617e-09
Epoch accuracy :  1.0
Test loss :  0.7096764302381386
Test Acc :  0.9551630434782609
Patience Counter :  5


 35%|███▍      | 19179/54800 [51:54<1:11:56,  8.25it/s] 

Epoch loss :  2.8959359095663117e-09
Epoch accuracy :  1.0
Test loss :  0.7135303526321025
Test Acc :  0.9558423913043478
Patience Counter :  6


 35%|███▌      | 19453/54800 [52:39<54:46, 10.76it/s]   

Epoch loss :  2.7055927312021655e-09
Epoch accuracy :  1.0
Test loss :  0.7177861146929227
Test Acc :  0.9558423913043478
Patience Counter :  7


 36%|███▌      | 19727/54800 [53:26<54:33, 10.71it/s]   

Epoch loss :  2.4336740011715636e-09
Epoch accuracy :  1.0
Test loss :  0.6998269346733992
Test Acc :  0.9572010869565217
Patience Counter :  8


 36%|███▋      | 20002/54800 [54:28<2:06:24,  4.59it/s] 

Epoch loss :  2.2161389738726958e-09
Epoch accuracy :  1.0
Test loss :  0.7058768765339561
Test Acc :  0.9565217391304348
Patience Counter :  9


 37%|███▋      | 20276/54800 [55:49<2:15:45,  4.24it/s] 

Epoch loss :  1.9986039060547397e-09
Epoch accuracy :  1.0
Test loss :  0.7294341675768877
Test Acc :  0.9558423913043478
Patience Counter :  10


 38%|███▊      | 20550/54800 [56:46<1:28:07,  6.48it/s] 

Epoch loss :  1.7946647712131526e-09
Epoch accuracy :  1.0
Test loss :  0.7255873916992246
Test Acc :  0.9565217391304348
Patience Counter :  11


 38%|███▊      | 20824/54800 [57:46<1:31:49,  6.17it/s] 

Epoch loss :  1.6451094760578188e-09
Epoch accuracy :  1.0
Test loss :  0.7198062878087018
Test Acc :  0.9572010869565217
Patience Counter :  12


 38%|███▊      | 21097/54800 [58:48<49:44, 11.29it/s]   

Epoch loss :  1.590725695529435e-09
Epoch accuracy :  1.0
Test loss :  0.7383674976768648
Test Acc :  0.9558423913043478
Patience Counter :  13


 39%|███▉      | 21371/54800 [59:42<1:19:37,  7.00it/s] 

Epoch loss :  1.4819581490595394e-09
Epoch accuracy :  1.0
Test loss :  0.7341735178965686
Test Acc :  0.9572010869565217
Patience Counter :  14


 40%|███▉      | 21646/54800 [1:00:35<1:01:38,  8.96it/s]

Epoch loss :  1.400382498526508e-09
Epoch accuracy :  1.0
Test loss :  0.6845195761858176
Test Acc :  0.9585597826086957


 40%|███▉      | 21918/54800 [1:00:52<12:43, 43.07it/s]  

Epoch loss :  1.2780189890961176e-09
Epoch accuracy :  1.0
Test loss :  0.7389559899240481
Test Acc :  0.9578804347826086
Patience Counter :  1


 40%|████      | 22193/54800 [1:01:27<1:14:46,  7.27it/s]

Epoch loss :  1.1556554893903084e-09
Epoch accuracy :  1.0
Test loss :  0.7298594375358987
Test Acc :  0.9585597826086957
Patience Counter :  2


 41%|████      | 22467/54800 [1:01:47<14:15, 37.81it/s]  

Epoch loss :  1.060483855637238e-09
Epoch accuracy :  1.0
Test loss :  0.7258903240982074
Test Acc :  0.9578804347826086
Patience Counter :  3


 42%|████▏     | 22742/54800 [1:02:00<13:32, 39.46it/s]  

Epoch loss :  1.0061000921268715e-09
Epoch accuracy :  1.0
Test loss :  0.6929361770201111
Test Acc :  0.9599184782608695


 42%|████▏     | 23015/54800 [1:02:55<1:28:47,  5.97it/s]

Epoch loss :  9.653122478163842e-10
Epoch accuracy :  1.0
Test loss :  0.7503113179989696
Test Acc :  0.9572010869565217
Patience Counter :  1


 42%|████▏     | 23285/54800 [1:03:09<15:20, 34.24it/s]  

Epoch loss :  9.109284697191458e-10
Epoch accuracy :  1.0
Test loss :  0.7528021848636696
Test Acc :  0.9578804347826086
Patience Counter :  2


 43%|████▎     | 23559/54800 [1:03:53<1:02:42,  8.30it/s]

Epoch loss :  8.565446989153434e-10
Epoch accuracy :  1.0
Test loss :  0.7514046790792627
Test Acc :  0.9578804347826086
Patience Counter :  3


 44%|████▎     | 23838/54800 [1:04:06<13:27, 38.35it/s]  

Epoch loss :  8.293528175653509e-10
Epoch accuracy :  1.0
Test loss :  0.7520462646714572
Test Acc :  0.9578804347826086
Patience Counter :  4


 44%|████▍     | 24110/54800 [1:04:20<16:34, 30.87it/s]  

Epoch loss :  8.293528118926785e-10
Epoch accuracy :  1.0
Test loss :  0.7157643658481645
Test Acc :  0.9592391304347826
Patience Counter :  5


 44%|████▍     | 24386/54800 [1:05:09<1:14:03,  6.84it/s]

Epoch loss :  7.477771257028494e-10
Epoch accuracy :  1.0


In [None]:
best_epoch

29

In [None]:
patience_counter

20

In [None]:
evaluate(query_encoder,test_dataloader)

Test loss :  1.0130133652836895
Test Acc :  0.9123641304347826


(1.0130133652836895, 0.9123641304347826)

 49%|████▉     | 6713/13700 [10:20<02:43, 42.69it/s]