# Goal:

In this assignment, you will implement the self-supervised contrastive learning algorithm, [SimCLR](https://arxiv.org/abs/2002.05709), using PyTorch. You will use the STL-10 dataset for this assignment.

You need to compete the `Net` class definition, the `SimCLRDataset` dataset class definition and the SimCLR loss in the `Trainer` class. You need to run the training loop, save the best training model and evaluate using the `linear probe` classfication task. Since we don't have enough GPU resources and contrastive learning algorithm like SimCLR usually needs around `1000` epoches to train (we only have `70` epoches), you may not get the best performance. Thus, as for the performance side, as long as you see the loss is decreasing (to around 7.4 at `70` epoch) and the accuracy is increasing, you are good to go.

Grade:

- **Fill the Net class definition (5 points).**
- **Fill the SimCLRDataset dataset class definition (10 points).**
- **Fill the SimCLR loss in the Trainer class (20 points).**
- **Record the training loss within 70 epochs, the lower the better (5 points).**
- **Record the linear probe accuracy, the higher the better (5 points).**
- **Write a report including:**
  - **How you select data augmentation (transform) in the transform pool.**
  - **How you implement the SimCLR loss and explain why your SimCLR loss is computationally efficient and equivalent to the loss function in the paper.**
  - **Include the training loss curve and the downstream accuracy (15 points). Note that the logging logic is not provided, please implement them before you start training.**
---
Please DO NOT change the config provided. Only change the given code if you are confident that the change is necessary. It is recommended that you **use CPU session to debug** when GPU is not necessary since Colab only gives 12 hrs of free GPU access at a time. If you use up the GPU resource, you may consider to use Kaggle GPU resource. Thank you and good luck!

# Self-supervised learning: SimCLR

Self-supervised learning

1.   Design an auxiliary task.
2.   Train the base network on the auxiliary task.
3.   Evaluate on the down-stream task: Train a new decoder based on the trained encoder.

Specifically, as one of the most successful self-supervised learning algorithm, SimCLR, a contrastive learning algorithm is what we focus today. Below, we are going to implement SimCLR as an example of self-supervised learning.


<img src="https://camo.githubusercontent.com/35af3432fbe91c56a934b5ee58931b4848ab35043830c9dd6f08fa41e6eadbe7/68747470733a2f2f312e62702e626c6f6773706f742e636f6d2f2d2d764834504b704539596f2f586f3461324259657276492f414141414141414146704d2f766146447750584f79416f6b4143385868383532447a4f67457332324e68625877434c63424741735948512f73313630302f696d616765342e676966" width="650" height="650">

In [1]:
# Config
# Since, we are using jupyter notebook, we use easydict to micic argparse. Feel free to use other format of config
from easydict import EasyDict
import torch.nn as nn
from tqdm import tqdm
import torch

config = {
    'dataset_name': 'stl10',
    'workers': 1,
    'epochs': 70,
    'batch_size': 2560,
    'lr': 0.0003,
    'weight_decay': 1e-4,
    'seed': 4242,
    'fp16_precision': True,
    'out_dim': 128,
    'temperature': 0.5,
    'n_views': 2,
    'device': "cuda" if torch.cuda.is_available() else "cpu",
}
args = EasyDict(config)

We are going to use [STL-10 dataset](https://cs.stanford.edu/~acoates/stl10/).

<img src="https://cs.stanford.edu/~acoates/stl10/images.png" width="450" height="450">

Overview

*   10 classes: airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck.
*   Images are **96x96** pixels, color.
*   500 training images (10 pre-defined folds), 800 test images per class.
*   100000 unlabeled images for unsupervised learning. These examples are extracted from a similar but broader distribution of images. For instance, it contains other types of animals (bears, rabbits, etc.) and vehicles (trains, buses, etc.) in addition to the ones in the labeled set.
*   Images were acquired from labeled examples on ImageNet.


## Preparation

Define a ResNet-18 and an additional MLP layer as the model training in the auxiliary task.

In [2]:
import random
import numpy as np

import torchvision.models as models

torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.basemodel = models.resnet18(pretrained=False, num_classes=args.out_dim)
        self.fc_in_features = self.basemodel.fc.in_features
        self.backup_fc = None
        # ToDO: define an MLP layer to insert between the last layer and the rest of the model
        self.basemodel.fc = nn.Sequential(
          nn.Linear(self.fc_in_features, self.fc_in_features),
          nn.ReLU(),
          self.basemodel.fc,
        )

    def forward(self, x):
        # ToDo: implement the forward logic
        output = self.basemodel(x)
        return output


    def linear_probe(self):
        self.freeze_basemodel_encoder()
        self.backup_fc = self.basemodel.fc  # Backup the last Linear layer
        # ToDo: implement the linear probe for your downstream task. A linear prob is just a linear layer (not MLP, no activation layer included) after the learned encoder.
        self.basemodel.fc = nn.Linear(self.fc_in_features, 10)


    def restore_backbone(self):
        self.basemodel.fc = self.backup_fc
        self.backup_fc = None

    def freeze_basemodel_encoder(self):
        # do not freeze the self.basemodel.fc weights
        for name, param in self.basemodel.named_parameters():
            if 'fc' not in name:
                param.requires_grad = False

# Step 1: Design the auxiliary task.
## construct the dataset

In [3]:
from torchvision import transforms, datasets

class View_sampler(object):
    """This class randomly sample two transforms from the list of transforms for the SimCLR to use. It is used in the SimCLRDataset.get_dataset."""

    def __init__(self, transforms, n_views=2):
        self.transforms = transforms
        self.n_views = n_views

    def __call__(self, x):
        return [self.transforms(x) for i in range(self.n_views)]


class SimCLRDataset:
    def __init__(self, root_folder="./datasets"):
        self.root_folder = root_folder

    @staticmethod
    def transforms_pool():
        # ToDo
        new_data = transforms.Compose([
            transforms.RandomResizedCrop(95),
            transforms.RandomHorizontalFlip(),
            transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),
            transforms.ColorJitter(brightness=0.35, contrast=0.35, saturation=0.35, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.3, 0.3, 0.3])
        ])
        return new_data

    def get_dataset(self):
        dataset_fn = lambda: datasets.STL10(self.root_folder, split='unlabeled', transform=View_sampler(self.transforms_pool(), 2), download=True)
        return dataset_fn()

## Define dataloader, optimizer and scheduler

What is a scheduler?

A scheduler helps in optimizing the convergence, avoiding local minima, and potentially improving the model's performance on the task at hand. The learning rate is one of the most important hyperparameters for training neural networks, and finding an appropriate learning rate schedule can be crucial for your model's success.

<img src="https://miro.medium.com/v2/resize:fit:4800/format:webp/1*qe6nYlH8zsmUdScyHMhRCQ.png" width="1200" height="450">

Read more here: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

In [4]:
from torch.utils.data import DataLoader
import torch.optim as optim

model = Net()
dataset = SimCLRDataset()
train_dataset = dataset.get_dataset()
# ToDo, define dataloader based on the train_dataset with drop_last=True
dataloader = DataLoader(
    train_dataset, 
    batch_size=args.batch_size, 
    shuffle=True, 
    drop_last=True,
    num_workers=args.workers
)
# ToDo, define an optimizer with args.lr as the learning rate and args.weight_decay as the weight_decay
optimizer = optim.Adam(
    model.parameters(), 
    lr=args.lr, 
    weight_decay=args.weight_decay
)
# ToDo, define an lr_scheduler CosineAnnealingLR for the optimizer
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=args.epochs,
    eta_min = 1e-4 
)



Files already downloaded and verified


# Define the trainer


Automatic Mixed Precision (AMP) is a technique that aims to improve the speed and efficiency of training deep neural networks by leveraging mixed-precision training.

**Introduction to AMP**

AMP allows neural network training to use both single-precision (FP32) and half-precision (FP16) floating point arithmetic simultaneously. The main idea behind AMP is to perform certain operations in FP16 to exploit the faster arithmetic and reduced memory usage of lower-precision computing, while maintaining the critical parts of the computation in FP32 to ensure model accuracy and stability.

**Why We Cannot Always Use FP16**



*   **Numerical Stability**: FP16 has a smaller dynamic range and lower precision compared to FP32. This limitation can lead to numerical instability, such as underflows and overflows, particularly during operations that involve small gradient values or require high numerical precision. This can adversely affect the convergence and accuracy of the trained model.
*   **Selective Precision Requirements**: Certain operations and layers within neural networks are more sensitive to precision than others. For example, weight updates in optimizers might require FP32 to maintain accuracy over time. AMP strategies, therefore, involve selectively applying FP16 to parts of the computation where it can be beneficial without undermining the overall training process.

Below, we present how to include AMP logic in standard torch training procedure.

Before including AMP:
```python
for batch in data_loader:
    # Forward pass
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

    # Backward pass and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

After including AMP:
```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for batch in data_loader:
    inputs, targets = batch[0].cuda(), batch[1].cuda()

    # Forward pass
    with autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

    # Backward pass and optimize
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```


Read more here: https://pytorch.org/docs/stable/amp.html

## Implement loss function in the Trainer
The algorithm of SimCLR Loss function is as follows:
$$
\begin{aligned}
&\text{for all } i \in \{1, \ldots, 2N\} \text{ and } j \in \{1, \ldots, 2N\} \text{ do} \\
&\quad s_{i,j} = \frac{z_i^\top z_j}{\|z_i\|\|z_j\|} \quad \text{# pairwise similarity} \\
&\text{end for} \\
&\text{define } \ell(i,j) \text{ as } \ell(i,j) = -\log \left( \frac{\exp(s_{i,j} / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq j]} \exp(s_{i,k} / \tau)} \right) \\
&\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[\ell(2k-1, 2k) + \ell(2k, 2k-1)\right] \\
&\text{update networks } f \text{ and } g \text{ to minimize } \mathcal{L}
\end{aligned}
$$

Please fill the blanks in the loss function below. Hint: implement mask to avoid including self-self similarity, and postive pairs and negative pairs.

In [5]:
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter(log_dir='./SimCLR_logs/')

class Trainer():
    def __init__(self, *args, **kwargs):
        self.args = kwargs['args']
        self.model = kwargs['model'].to(self.args.device)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)


    def loss(self, features):
        # input features isrch tensor with shape of (2*batch_size, out_dim)
        # The positive pairs are (features[i] and features[i+batch_size]) for all i
        # TODO: implement the loss function
        half_batch_size = features.shape[0] // 2
        
        # Generating labels for matching pairs
        full_labels = torch.arange(half_batch_size, device=self.args.device)
        concatenated_labels = torch.cat((full_labels, full_labels), dim=0)

        # Normalize features to unit length
        normalized_features = F.normalize(features, dim=1)

        # Compute similarity matrix
        sim_matrix = torch.mm(normalized_features, normalized_features.T)

        # Mask to filter out self-positives from the diagonal
        diag_mask = torch.eye(2 * half_batch_size, dtype=torch.bool, device=self.args.device)

        # Compute logit scores for positive and negative pairs
        positive_indices = torch.arange(half_batch_size, device=self.args.device)
        positive_sim_scores = sim_matrix[positive_indices, positive_indices + half_batch_size]
        negative_sim_scores = sim_matrix[~diag_mask].view(half_batch_size, -1)

        # Stacking positives and negatives and normalizing by temperature
        stacked_logits = torch.cat([positive_sim_scores.unsqueeze(1), negative_sim_scores], dim=1) / self.args.temperature

        # Ground truth labels: zeros since the positives are at index 0
        target_labels = torch.zeros(half_batch_size, dtype=torch.long, device=self.args.device)

        # Cross entropy loss between scores and true labels
        contrastive_loss = F.cross_entropy(stacked_logits, target_labels).mean()
        
        return contrastive_loss
        
    def train(self, dataloader):
        # implement GradScaler if AMP
        best_loss = 1e4
        scaler = GradScaler(enabled=self.args.fp16_precision)
        for epoch in range(self.args.epochs):
            for images, _ in tqdm(dataloader):
                images = torch.cat(images, dim=0)
                images = images.to(self.args.device)

                with autocast(enabled=self.args.fp16_precision):
                    features = self.model(images)
                    loss = self.loss(features)

                # features = self.model(images)
                # loss = self.loss(features)
                self.optimizer.zero_grad()

                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()

                # loss.backward()
                # self.optimizer.step()
                self.scheduler.step()

            # warmup for the first 10 epochs
            if epoch >= 10:
                self.scheduler.step()
            if epoch % 10 == 0 and epoch != 0:
                self.save_model(self.model, f"model_{epoch}.pth")
            # save the lowest loss model
            # feel free to implement your own logic to save the best model
            if loss < best_loss:
                best_loss = loss
                self.save_model(self.model, f"best_model.pth")
            print(f"Epoch {epoch}, Loss {loss.item()}")
            writer.add_scalar("SSL_Loss/Train", loss, epoch)
            writer.add_scalar("SSL_AVG_Loss/Train", loss/len(dataloader), epoch)
        return self.model

    def save_model(self, model, path):
        torch.save(model.state_dict(), path)

# Step 2: Train the base network on the auxiliary task for 70 epoch and save the best model you have for evaluation.

Check if the training loss drops over time and try to capture other possible bug using logging tools. Each epoch should take around 7 minutes. The loss should be expected to around 7.4.

In [6]:
trainer = Trainer(args=args, model=model, optimizer=optimizer, scheduler=lr_scheduler)
trainer.train(dataloader)

  0%|          | 0/39 [00:00<?, ?it/s]

100%|██████████| 39/39 [06:28<00:00,  9.96s/it]


Epoch 0, Loss 8.299324035644531


100%|██████████| 39/39 [06:29<00:00,  9.98s/it]


Epoch 1, Loss 8.16845703125


100%|██████████| 39/39 [06:29<00:00,  9.99s/it]


Epoch 2, Loss 8.132220268249512


100%|██████████| 39/39 [06:28<00:00,  9.96s/it]


Epoch 3, Loss 8.094571113586426


100%|██████████| 39/39 [06:29<00:00, 10.00s/it]


Epoch 4, Loss 8.03593921661377


100%|██████████| 39/39 [06:29<00:00, 10.00s/it]


Epoch 5, Loss 8.032560348510742


100%|██████████| 39/39 [06:30<00:00, 10.01s/it]


Epoch 6, Loss 8.024721145629883


100%|██████████| 39/39 [06:28<00:00,  9.97s/it]


Epoch 7, Loss 7.998046875


100%|██████████| 39/39 [06:29<00:00,  9.98s/it]


Epoch 8, Loss 7.969766139984131


100%|██████████| 39/39 [06:28<00:00,  9.97s/it]


Epoch 9, Loss 7.958639621734619


100%|██████████| 39/39 [06:29<00:00,  9.99s/it]


Epoch 10, Loss 7.926672458648682


100%|██████████| 39/39 [06:29<00:00, 10.00s/it]


Epoch 11, Loss 7.928199768066406


100%|██████████| 39/39 [06:29<00:00,  9.99s/it]


Epoch 12, Loss 7.917078971862793


100%|██████████| 39/39 [06:29<00:00,  9.99s/it]


Epoch 13, Loss 7.919346809387207


100%|██████████| 39/39 [06:29<00:00,  9.99s/it]


Epoch 14, Loss 7.918848991394043


100%|██████████| 39/39 [06:28<00:00,  9.95s/it]


Epoch 15, Loss 7.886206150054932


100%|██████████| 39/39 [06:26<00:00,  9.90s/it]


Epoch 16, Loss 7.901492118835449


100%|██████████| 39/39 [06:24<00:00,  9.86s/it]


Epoch 17, Loss 7.884333610534668


100%|██████████| 39/39 [06:25<00:00,  9.88s/it]


Epoch 18, Loss 7.8777265548706055


100%|██████████| 39/39 [06:25<00:00,  9.88s/it]


Epoch 19, Loss 7.8714399337768555


100%|██████████| 39/39 [06:25<00:00,  9.88s/it]


Epoch 20, Loss 7.8748579025268555


100%|██████████| 39/39 [06:24<00:00,  9.86s/it]


Epoch 21, Loss 7.882998466491699


100%|██████████| 39/39 [06:24<00:00,  9.86s/it]


Epoch 22, Loss 7.851782321929932


100%|██████████| 39/39 [06:25<00:00,  9.88s/it]


Epoch 23, Loss 7.861990451812744


100%|██████████| 39/39 [06:24<00:00,  9.85s/it]


Epoch 24, Loss 7.8796796798706055


100%|██████████| 39/39 [06:20<00:00,  9.75s/it]


Epoch 25, Loss 7.856381416320801


100%|██████████| 39/39 [06:16<00:00,  9.64s/it]


Epoch 26, Loss 7.856183052062988


100%|██████████| 39/39 [06:15<00:00,  9.64s/it]


Epoch 27, Loss 7.8542022705078125


100%|██████████| 39/39 [06:15<00:00,  9.63s/it]


Epoch 28, Loss 7.8615312576293945


100%|██████████| 39/39 [06:16<00:00,  9.66s/it]


Epoch 29, Loss 7.832898139953613


100%|██████████| 39/39 [06:16<00:00,  9.66s/it]


Epoch 30, Loss 7.8318071365356445


100%|██████████| 39/39 [06:23<00:00,  9.82s/it]


Epoch 31, Loss 7.840065002441406


100%|██████████| 39/39 [06:29<00:00,  9.99s/it]


Epoch 32, Loss 7.833921909332275


100%|██████████| 39/39 [06:27<00:00,  9.93s/it]


Epoch 33, Loss 7.8255815505981445


100%|██████████| 39/39 [06:21<00:00,  9.79s/it]


Epoch 34, Loss 7.8358306884765625


100%|██████████| 39/39 [06:16<00:00,  9.65s/it]


Epoch 35, Loss 7.8354172706604


100%|██████████| 39/39 [06:16<00:00,  9.65s/it]


Epoch 36, Loss 7.817986965179443


100%|██████████| 39/39 [06:16<00:00,  9.64s/it]


Epoch 37, Loss 7.81381368637085


100%|██████████| 39/39 [06:16<00:00,  9.66s/it]


Epoch 38, Loss 7.8396315574646


100%|██████████| 39/39 [06:16<00:00,  9.67s/it]


Epoch 39, Loss 7.816778659820557


100%|██████████| 39/39 [06:18<00:00,  9.70s/it]


Epoch 40, Loss 7.8061723709106445


100%|██████████| 39/39 [06:30<00:00, 10.01s/it]


Epoch 41, Loss 7.815589904785156


100%|██████████| 39/39 [06:27<00:00,  9.95s/it]


Epoch 42, Loss 7.812707424163818


100%|██████████| 39/39 [06:23<00:00,  9.83s/it]


Epoch 43, Loss 7.787785530090332


100%|██████████| 39/39 [06:19<00:00,  9.74s/it]


Epoch 44, Loss 7.7909040451049805


100%|██████████| 39/39 [06:16<00:00,  9.65s/it]


Epoch 45, Loss 7.820065498352051


100%|██████████| 39/39 [06:15<00:00,  9.64s/it]


Epoch 46, Loss 7.796164035797119


100%|██████████| 39/39 [06:16<00:00,  9.65s/it]


Epoch 47, Loss 7.791668891906738


100%|██████████| 39/39 [06:17<00:00,  9.67s/it]


Epoch 48, Loss 7.808434963226318


100%|██████████| 39/39 [06:15<00:00,  9.64s/it]


Epoch 49, Loss 7.8046464920043945


100%|██████████| 39/39 [06:24<00:00,  9.86s/it]


Epoch 50, Loss 7.799729824066162


100%|██████████| 39/39 [06:29<00:00,  9.98s/it]


Epoch 51, Loss 7.7912750244140625


100%|██████████| 39/39 [06:29<00:00,  9.98s/it]


Epoch 52, Loss 7.797071933746338


100%|██████████| 39/39 [06:27<00:00,  9.94s/it]


Epoch 53, Loss 7.788862705230713


100%|██████████| 39/39 [06:25<00:00,  9.89s/it]


Epoch 54, Loss 7.774940490722656


100%|██████████| 39/39 [06:29<00:00,  9.98s/it]


Epoch 55, Loss 7.784135341644287


100%|██████████| 39/39 [06:29<00:00,  9.99s/it]


Epoch 56, Loss 7.796786308288574


100%|██████████| 39/39 [06:29<00:00,  9.99s/it]


Epoch 57, Loss 7.769906520843506


100%|██████████| 39/39 [06:29<00:00,  9.98s/it]


Epoch 58, Loss 7.778923034667969


100%|██████████| 39/39 [06:29<00:00, 10.00s/it]


Epoch 59, Loss 7.795539855957031


100%|██████████| 39/39 [06:29<00:00,  9.98s/it]


Epoch 60, Loss 7.771317958831787


100%|██████████| 39/39 [06:28<00:00,  9.96s/it]


Epoch 61, Loss 7.765118598937988


100%|██████████| 39/39 [06:28<00:00,  9.96s/it]


Epoch 62, Loss 7.771035671234131


100%|██████████| 39/39 [06:25<00:00,  9.87s/it]


Epoch 63, Loss 7.80230712890625


100%|██████████| 39/39 [06:28<00:00,  9.96s/it]


Epoch 64, Loss 7.756822109222412


100%|██████████| 39/39 [06:25<00:00,  9.90s/it]


Epoch 65, Loss 7.780392646789551


100%|██████████| 39/39 [06:19<00:00,  9.72s/it]


Epoch 66, Loss 7.789923191070557


100%|██████████| 39/39 [06:27<00:00,  9.93s/it]


Epoch 67, Loss 7.77004861831665


100%|██████████| 39/39 [06:26<00:00,  9.92s/it]


Epoch 68, Loss 7.746987819671631


100%|██████████| 39/39 [06:24<00:00,  9.86s/it]

Epoch 69, Loss 7.783017158508301





Net(
  (basemodel): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runni

# Step 3: Evaluate on the down-stream task: Train a new MLP decoder based on the trained encoder.

This finetune process should be way faster than the previous one. The expected Top-1 accuracy should be around 57% and the Top-5 accuracy should be around 97%. Getting this results is normal because linear prob is just a projection layer usually recognized as not having representation ability. To achieve the performance mentioned in the paper, we need larger dataset and more powerful GPU and longer time (about 1000 epoches during the pretraining stage).

In [7]:
class linear_prob_Trainer:
    def __init__(self, *args, **kwargs):
        self.args = kwargs["args"]
        self.model = kwargs["model"].to(self.args.device)
        self.optimizer = kwargs["optimizer"]
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
        self.train_dataset = datasets.STL10(
            "./data", split="train", download=True, transform=transforms.ToTensor()
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            num_workers=1,
            drop_last=False,
        )

        self.test_dataset = datasets.STL10(
            "./data", split="test", download=True, transform=transforms.ToTensor()
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.args.batch_size,
            num_workers=1,
            drop_last=False,
        )

    def accuracy(self, output, target, topk=(1,)):
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    def train(self, dataloader):
        for epoch in range(100):
            top1_train_accuracy = 0
            for images, labels in tqdm(dataloader):
                images, labels = images.to(self.args.device), labels.to(
                    self.args.device
                )
                logits = self.model(images)
                loss = self.criterion(logits, labels)
                top1 = self.accuracy(logits, labels, topk=(1,))
                top1_train_accuracy += top1[0]

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                print(
                    f"Epoch: {epoch}, Loss: {loss.item()}",
                    "Top1 Train Accuracy: ",
                    top1_train_accuracy.item() / len(dataloader),
                )
                writer.add_scalar("Linear_Prob_Loss/Train", loss.item(), epoch)
                writer.add_scalar("Top1_ACC/Train", top1_train_accuracy.item() / len(dataloader), epoch)
        return self.model

    def test(self, dataloader):
        with torch.no_grad():
            model.eval()
            top1_test_accuracy = 0
            top5_test_accuracy = 0
            for images, labels in tqdm(dataloader):
                images, labels = images.to(self.args.device), labels.to(
                    self.args.device
                )
                logits = self.model(images)
                top1 = self.accuracy(logits, labels, topk=(1,))
                top1_test_accuracy += top1[0]
                top5 = self.accuracy(logits, labels, topk=(5,))
                top5_test_accuracy += top5[0]
            print("Top1 Test Accuracy: ", top1_test_accuracy.item() / len(dataloader))
            print("Top5 Test Accuracy: ", top5_test_accuracy.item() / len(dataloader))
            writer.add_scalar("Top1_ACC/Test", top1_test_accuracy / len(dataloader.dataset), 0)
            writer.add_scalar("Top5_ACC/Test", top5_test_accuracy / len(dataloader.dataset), 0)
            return
# load your model here if you need to resume
# ...
model.linear_probe()
linear_probe_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
linear_prob_trainer = linear_prob_Trainer(args=args, model=model, optimizer=linear_probe_optimizer)
model = linear_prob_trainer.train(linear_prob_trainer.train_loader)
linear_prob_trainer.test(linear_prob_trainer.test_loader)

Files already downloaded and verified
Files already downloaded and verified


 50%|█████     | 1/2 [00:01<00:01,  1.16s/it]

Epoch: 0, Loss: 2.400590181350708 Top1 Train Accuracy:  4.82421875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 0, Loss: 2.35282039642334 Top1 Train Accuracy:  10.643890380859375


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 1, Loss: 2.307852268218994 Top1 Train Accuracy:  6.8359375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 1, Loss: 2.263091564178467 Top1 Train Accuracy:  15.5859375


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 2, Loss: 2.223385810852051 Top1 Train Accuracy:  9.9609375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 2, Loss: 2.181124687194824 Top1 Train Accuracy:  21.33388900756836


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 3, Loss: 2.146129608154297 Top1 Train Accuracy:  12.59765625


100%|██████████| 2/2 [00:01<00:00,  1.05it/s]


Epoch: 3, Loss: 2.1060338020324707 Top1 Train Accuracy:  26.142738342285156


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 4, Loss: 2.0751237869262695 Top1 Train Accuracy:  14.12109375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 4, Loss: 2.0368995666503906 Top1 Train Accuracy:  29.63338851928711


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 5, Loss: 2.0094895362854004 Top1 Train Accuracy:  15.5078125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 5, Loss: 1.9729236364364624 Top1 Train Accuracy:  32.5364990234375


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 6, Loss: 1.9485208988189697 Top1 Train Accuracy:  16.8359375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 6, Loss: 1.9134865999221802 Top1 Train Accuracy:  35.299049377441406


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 7, Loss: 1.8917003870010376 Top1 Train Accuracy:  18.0078125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 7, Loss: 1.8581334352493286 Top1 Train Accuracy:  37.823387145996094


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 8, Loss: 1.8386627435684204 Top1 Train Accuracy:  19.55078125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 8, Loss: 1.806526780128479 Top1 Train Accuracy:  40.616355895996094


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 9, Loss: 1.7891395092010498 Top1 Train Accuracy:  20.7421875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 9, Loss: 1.7583966255187988 Top1 Train Accuracy:  42.9143180847168


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 10, Loss: 1.7429115772247314 Top1 Train Accuracy:  21.93359375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 10, Loss: 1.713508129119873 Top1 Train Accuracy:  45.29425048828125


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 11, Loss: 1.6997840404510498 Top1 Train Accuracy:  23.203125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 11, Loss: 1.6716464757919312 Top1 Train Accuracy:  47.7318115234375


 50%|█████     | 1/2 [00:01<00:01,  1.13s/it]

Epoch: 12, Loss: 1.6595735549926758 Top1 Train Accuracy:  24.39453125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 12, Loss: 1.6326121091842651 Top1 Train Accuracy:  49.968299865722656


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 13, Loss: 1.6221048831939697 Top1 Train Accuracy:  25.0


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 13, Loss: 1.5962228775024414 Top1 Train Accuracy:  51.454917907714844


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 14, Loss: 1.5872132778167725 Top1 Train Accuracy:  25.8203125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 14, Loss: 1.5623129606246948 Top1 Train Accuracy:  52.82850646972656


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 15, Loss: 1.5547411441802979 Top1 Train Accuracy:  26.50390625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 15, Loss: 1.5307289361953735 Top1 Train Accuracy:  53.88095474243164


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 16, Loss: 1.5245355367660522 Top1 Train Accuracy:  26.97265625


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 16, Loss: 1.5013233423233032 Top1 Train Accuracy:  54.63658905029297


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 17, Loss: 1.4964443445205688 Top1 Train Accuracy:  27.48046875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 17, Loss: 1.4739506244659424 Top1 Train Accuracy:  55.51325607299805


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 18, Loss: 1.47031569480896 Top1 Train Accuracy:  27.8125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 18, Loss: 1.448466420173645 Top1 Train Accuracy:  56.05020523071289


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 19, Loss: 1.446001410484314 Top1 Train Accuracy:  28.10546875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 19, Loss: 1.4247291088104248 Top1 Train Accuracy:  56.589073181152344


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 20, Loss: 1.4233596324920654 Top1 Train Accuracy:  28.33984375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 20, Loss: 1.402601957321167 Top1 Train Accuracy:  56.782466888427734


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 21, Loss: 1.4022579193115234 Top1 Train Accuracy:  28.5546875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 21, Loss: 1.3819575309753418 Top1 Train Accuracy:  57.18173599243164


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 22, Loss: 1.3825758695602417 Top1 Train Accuracy:  28.7890625


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 22, Loss: 1.3626773357391357 Top1 Train Accuracy:  57.5390625


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 23, Loss: 1.364205002784729 Top1 Train Accuracy:  28.88671875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 23, Loss: 1.3446539640426636 Top1 Train Accuracy:  57.759666442871094


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 24, Loss: 1.347045660018921 Top1 Train Accuracy:  28.984375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 24, Loss: 1.3277878761291504 Top1 Train Accuracy:  57.939292907714844


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 25, Loss: 1.331007719039917 Top1 Train Accuracy:  29.21875


100%|██████████| 2/2 [00:01<00:00,  1.05it/s]


Epoch: 25, Loss: 1.3119889497756958 Top1 Train Accuracy:  58.21465301513672


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 26, Loss: 1.3160042762756348 Top1 Train Accuracy:  29.39453125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 26, Loss: 1.2971724271774292 Top1 Train Accuracy:  58.43141555786133


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 27, Loss: 1.301954746246338 Top1 Train Accuracy:  29.453125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 27, Loss: 1.2832595109939575 Top1 Train Accuracy:  58.57197570800781


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 28, Loss: 1.2887816429138184 Top1 Train Accuracy:  29.53125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 28, Loss: 1.2701777219772339 Top1 Train Accuracy:  58.67059326171875


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 29, Loss: 1.276411771774292 Top1 Train Accuracy:  29.4921875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 29, Loss: 1.2578598260879517 Top1 Train Accuracy:  58.75448226928711


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 30, Loss: 1.2647783756256104 Top1 Train Accuracy:  29.5703125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 30, Loss: 1.2462444305419922 Top1 Train Accuracy:  58.87358856201172


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 31, Loss: 1.253820776939392 Top1 Train Accuracy:  29.6875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 31, Loss: 1.2352763414382935 Top1 Train Accuracy:  59.257171630859375


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 32, Loss: 1.243483304977417 Top1 Train Accuracy:  29.8046875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 32, Loss: 1.2249047756195068 Top1 Train Accuracy:  59.497310638427734


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 33, Loss: 1.2337167263031006 Top1 Train Accuracy:  29.8828125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 33, Loss: 1.2150839567184448 Top1 Train Accuracy:  59.677894592285156


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 34, Loss: 1.2244765758514404 Top1 Train Accuracy:  30.0


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 34, Loss: 1.2057709693908691 Top1 Train Accuracy:  59.91802978515625


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 35, Loss: 1.2157213687896729 Top1 Train Accuracy:  30.078125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 35, Loss: 1.1969274282455444 Top1 Train Accuracy:  60.05763244628906


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 36, Loss: 1.207414984703064 Top1 Train Accuracy:  30.09765625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 36, Loss: 1.1885178089141846 Top1 Train Accuracy:  60.09765625


 50%|█████     | 1/2 [00:01<00:01,  1.17s/it]

Epoch: 37, Loss: 1.199522852897644 Top1 Train Accuracy:  30.17578125


100%|██████████| 2/2 [00:01<00:00,  1.05it/s]


Epoch: 37, Loss: 1.1805087327957153 Top1 Train Accuracy:  60.257747650146484


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 38, Loss: 1.1920143365859985 Top1 Train Accuracy:  30.1953125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 38, Loss: 1.1728719472885132 Top1 Train Accuracy:  60.33875274658203


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 39, Loss: 1.184861183166504 Top1 Train Accuracy:  30.2734375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 39, Loss: 1.1655806303024292 Top1 Train Accuracy:  60.457862854003906


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 40, Loss: 1.1780377626419067 Top1 Train Accuracy:  30.3125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 40, Loss: 1.1586101055145264 Top1 Train Accuracy:  60.435447692871094


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 41, Loss: 1.1715209484100342 Top1 Train Accuracy:  30.41015625


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 41, Loss: 1.151938557624817 Top1 Train Accuracy:  60.615074157714844


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 42, Loss: 1.1652895212173462 Top1 Train Accuracy:  30.390625


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 42, Loss: 1.145545482635498 Top1 Train Accuracy:  60.7184944152832


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 43, Loss: 1.1593241691589355 Top1 Train Accuracy:  30.44921875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 43, Loss: 1.1394129991531372 Top1 Train Accuracy:  60.94102096557617


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 44, Loss: 1.1536071300506592 Top1 Train Accuracy:  30.48828125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 44, Loss: 1.133523941040039 Top1 Train Accuracy:  61.00057601928711


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 45, Loss: 1.1481223106384277 Top1 Train Accuracy:  30.46875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 45, Loss: 1.1278626918792725 Top1 Train Accuracy:  61.063011169433594


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 46, Loss: 1.1428550481796265 Top1 Train Accuracy:  30.5078125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 46, Loss: 1.1224151849746704 Top1 Train Accuracy:  61.184043884277344


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 47, Loss: 1.137791395187378 Top1 Train Accuracy:  30.60546875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 47, Loss: 1.1171692609786987 Top1 Train Accuracy:  61.32268142700195


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 48, Loss: 1.1329193115234375 Top1 Train Accuracy:  30.6640625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 48, Loss: 1.1121129989624023 Top1 Train Accuracy:  61.360782623291016


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 49, Loss: 1.1282262802124023 Top1 Train Accuracy:  30.68359375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 49, Loss: 1.1072351932525635 Top1 Train Accuracy:  61.380313873291016


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 50, Loss: 1.123701572418213 Top1 Train Accuracy:  30.703125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 50, Loss: 1.102527141571045 Top1 Train Accuracy:  61.42033767700195


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 51, Loss: 1.1193352937698364 Top1 Train Accuracy:  30.76171875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 51, Loss: 1.0979788303375244 Top1 Train Accuracy:  61.437950134277344


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 52, Loss: 1.115118384361267 Top1 Train Accuracy:  30.703125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 52, Loss: 1.093582272529602 Top1 Train Accuracy:  61.44083023071289


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 53, Loss: 1.1110423803329468 Top1 Train Accuracy:  30.7421875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 53, Loss: 1.0893290042877197 Top1 Train Accuracy:  61.54136657714844


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 54, Loss: 1.1070995330810547 Top1 Train Accuracy:  30.83984375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 54, Loss: 1.085211992263794 Top1 Train Accuracy:  61.72099304199219


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 55, Loss: 1.1032826900482178 Top1 Train Accuracy:  30.8984375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 55, Loss: 1.081223726272583 Top1 Train Accuracy:  61.800079345703125


 50%|█████     | 1/2 [00:01<00:01,  1.17s/it]

Epoch: 56, Loss: 1.099585771560669 Top1 Train Accuracy:  30.8984375


100%|██████████| 2/2 [00:01<00:00,  1.05it/s]


Epoch: 56, Loss: 1.0773582458496094 Top1 Train Accuracy:  61.882041931152344


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 57, Loss: 1.0960023403167725 Top1 Train Accuracy:  30.95703125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 57, Loss: 1.0736087560653687 Top1 Train Accuracy:  62.022605895996094


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 58, Loss: 1.0925265550613403 Top1 Train Accuracy:  30.9765625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 58, Loss: 1.069969892501831 Top1 Train Accuracy:  62.06262969970703


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 59, Loss: 1.0891531705856323 Top1 Train Accuracy:  31.0546875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 59, Loss: 1.066436529159546 Top1 Train Accuracy:  62.16124725341797


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 60, Loss: 1.0858776569366455 Top1 Train Accuracy:  31.11328125


100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


Epoch: 60, Loss: 1.0630033016204834 Top1 Train Accuracy:  62.30180358886719


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 61, Loss: 1.0826942920684814 Top1 Train Accuracy:  31.09375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 61, Loss: 1.059666395187378 Top1 Train Accuracy:  62.26177978515625


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 62, Loss: 1.0795994997024536 Top1 Train Accuracy:  31.1328125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 62, Loss: 1.0564208030700684 Top1 Train Accuracy:  62.341827392578125


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 63, Loss: 1.0765888690948486 Top1 Train Accuracy:  31.171875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 63, Loss: 1.0532629489898682 Top1 Train Accuracy:  62.44236755371094


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 64, Loss: 1.0736582279205322 Top1 Train Accuracy:  31.23046875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 64, Loss: 1.0501888990402222 Top1 Train Accuracy:  62.48046875


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 65, Loss: 1.0708041191101074 Top1 Train Accuracy:  31.2109375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 65, Loss: 1.0471950769424438 Top1 Train Accuracy:  62.542903900146484


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 66, Loss: 1.0680229663848877 Top1 Train Accuracy:  31.2109375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 66, Loss: 1.0442777872085571 Top1 Train Accuracy:  62.56339645385742


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 67, Loss: 1.0653116703033447 Top1 Train Accuracy:  31.23046875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 67, Loss: 1.0414345264434814 Top1 Train Accuracy:  62.603416442871094


 50%|█████     | 1/2 [00:01<00:01,  1.16s/it]

Epoch: 68, Loss: 1.0626671314239502 Top1 Train Accuracy:  31.26953125


100%|██████████| 2/2 [00:01<00:00,  1.05it/s]


Epoch: 68, Loss: 1.0386615991592407 Top1 Train Accuracy:  62.66297149658203


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 69, Loss: 1.060086965560913 Top1 Train Accuracy:  31.328125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 69, Loss: 1.0359562635421753 Top1 Train Accuracy:  62.72156524658203


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 70, Loss: 1.0575674772262573 Top1 Train Accuracy:  31.328125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 70, Loss: 1.0333157777786255 Top1 Train Accuracy:  62.783042907714844


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 71, Loss: 1.0551068782806396 Top1 Train Accuracy:  31.40625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 71, Loss: 1.030738115310669 Top1 Train Accuracy:  62.90215301513672


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 72, Loss: 1.052702784538269 Top1 Train Accuracy:  31.40625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 72, Loss: 1.0282201766967773 Top1 Train Accuracy:  62.922645568847656


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 73, Loss: 1.0503523349761963 Top1 Train Accuracy:  31.3671875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 73, Loss: 1.0257600545883179 Top1 Train Accuracy:  62.924564361572266


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 74, Loss: 1.0480539798736572 Top1 Train Accuracy:  31.3671875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 74, Loss: 1.0233550071716309 Top1 Train Accuracy:  63.00653076171875


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 75, Loss: 1.0458056926727295 Top1 Train Accuracy:  31.4453125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 75, Loss: 1.0210034847259521 Top1 Train Accuracy:  63.06416320800781


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 76, Loss: 1.043605089187622 Top1 Train Accuracy:  31.484375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 76, Loss: 1.0187033414840698 Top1 Train Accuracy:  63.12371826171875


 50%|█████     | 1/2 [00:01<00:01,  1.16s/it]

Epoch: 77, Loss: 1.04145085811615 Top1 Train Accuracy:  31.484375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 77, Loss: 1.0164525508880615 Top1 Train Accuracy:  63.164703369140625


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 78, Loss: 1.039340853691101 Top1 Train Accuracy:  31.50390625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 78, Loss: 1.014249563217163 Top1 Train Accuracy:  63.20472717285156


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 79, Loss: 1.0372737646102905 Top1 Train Accuracy:  31.62109375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 79, Loss: 1.012092113494873 Top1 Train Accuracy:  63.40388107299805


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 80, Loss: 1.0352481603622437 Top1 Train Accuracy:  31.62109375


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 80, Loss: 1.009979009628296 Top1 Train Accuracy:  63.465354919433594


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 81, Loss: 1.033262014389038 Top1 Train Accuracy:  31.66015625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 81, Loss: 1.0079084634780884 Top1 Train Accuracy:  63.565895080566406


 50%|█████     | 1/2 [00:01<00:01,  1.16s/it]

Epoch: 82, Loss: 1.031314492225647 Top1 Train Accuracy:  31.6796875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 82, Loss: 1.0058794021606445 Top1 Train Accuracy:  63.626407623291016


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 83, Loss: 1.029403805732727 Top1 Train Accuracy:  31.7578125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 83, Loss: 1.003890037536621 Top1 Train Accuracy:  63.76601028442383


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 84, Loss: 1.0275288820266724 Top1 Train Accuracy:  31.7578125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 84, Loss: 1.001939058303833 Top1 Train Accuracy:  63.827484130859375


 50%|█████     | 1/2 [00:01<00:01,  1.17s/it]

Epoch: 85, Loss: 1.025688648223877 Top1 Train Accuracy:  31.81640625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 85, Loss: 1.0000251531600952 Top1 Train Accuracy:  63.92706298828125


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 86, Loss: 1.023881435394287 Top1 Train Accuracy:  31.89453125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 86, Loss: 0.9981471300125122 Top1 Train Accuracy:  64.087158203125


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 87, Loss: 1.022106409072876 Top1 Train Accuracy:  31.953125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 87, Loss: 0.9963037967681885 Top1 Train Accuracy:  64.18672943115234


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 88, Loss: 1.0203624963760376 Top1 Train Accuracy:  32.01171875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 88, Loss: 0.9944941997528076 Top1 Train Accuracy:  64.28630828857422


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 89, Loss: 1.0186493396759033 Top1 Train Accuracy:  32.03125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 89, Loss: 0.9927170276641846 Top1 Train Accuracy:  64.30583953857422


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 90, Loss: 1.0169645547866821 Top1 Train Accuracy:  32.05078125


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 90, Loss: 0.9909712076187134 Top1 Train Accuracy:  64.30487823486328


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 91, Loss: 1.0153086185455322 Top1 Train Accuracy:  32.08984375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 91, Loss: 0.9892560839653015 Top1 Train Accuracy:  64.34394073486328


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 92, Loss: 1.0136796236038208 Top1 Train Accuracy:  32.109375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 92, Loss: 0.9875702261924744 Top1 Train Accuracy:  64.38396453857422


 50%|█████     | 1/2 [00:01<00:01,  1.16s/it]

Epoch: 93, Loss: 1.0120774507522583 Top1 Train Accuracy:  32.1484375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 93, Loss: 0.9859131574630737 Top1 Train Accuracy:  64.40253448486328


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 94, Loss: 1.0105006694793701 Top1 Train Accuracy:  32.16796875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 94, Loss: 0.9842836260795593 Top1 Train Accuracy:  64.44255828857422


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 95, Loss: 1.0089490413665771 Top1 Train Accuracy:  32.16796875


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 95, Loss: 0.9826813340187073 Top1 Train Accuracy:  64.44255828857422


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

Epoch: 96, Loss: 1.0074212551116943 Top1 Train Accuracy:  32.1875


100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


Epoch: 96, Loss: 0.9811047911643982 Top1 Train Accuracy:  64.5235595703125


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 97, Loss: 1.0059171915054321 Top1 Train Accuracy:  32.20703125


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 97, Loss: 0.9795536398887634 Top1 Train Accuracy:  64.56358337402344


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 98, Loss: 1.0044358968734741 Top1 Train Accuracy:  32.24609375


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 98, Loss: 0.9780271649360657 Top1 Train Accuracy:  64.60264587402344


 50%|█████     | 1/2 [00:01<00:01,  1.15s/it]

Epoch: 99, Loss: 1.002976655960083 Top1 Train Accuracy:  32.28515625


100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch: 99, Loss: 0.9765240550041199 Top1 Train Accuracy:  64.66220092773438


100%|██████████| 4/4 [00:02<00:00,  1.50it/s]

Top1 Test Accuracy:  62.705078125
Top5 Test Accuracy:  97.744140625



