Mohammadhossein Malekpour - 20269421

# Goal:

In this assignment, you will implement the self-supervised contrastive learning algorithm, [SimCLR](https://arxiv.org/abs/2002.05709), using PyTorch. You will use the STL-10 dataset for this assignment.

You need to compete the `Net` class definition, the `SimCLRDataset` dataset class definition and the SimCLR loss in the `Trainer` class. You need to run the training loop, save the best training model and evaluate using the `linear probe` classfication task. Since we don't have enough GPU resources and contrastive learning algorithm like SimCLR usually needs around `1000` epoches to train (we only have `70` epoches), you may not get the best performance. Thus, as for the performance side, as long as you see the loss is decreasing (to around 7.4 at `70` epoch) and the accuracy is increasing, you are good to go.

Grade:

- **Fill the Net class definition (5 points).**
- **Fill the SimCLRDataset dataset class definition (10 points).**
- **Fill the SimCLR loss in the Trainer class (20 points).**
- **Record the training loss within 70 epochs, the lower the better (5 points).**
- **Record the linear probe accuracy, the higher the better (5 points).**
- **Write a report including:**
  - **How you select data augmentation (transform) in the transform pool.**
  - **How you implement the SimCLR loss and explain why your SimCLR loss is computationally efficient and equivalent to the loss function in the paper.**
  - **Include the training loss curve and the downstream accuracy (15 points). Note that the logging logic is not provided, please implement them before you start training.**
---
Please DO NOT change the config provided. Only change the given code if you are confident that the change is necessary. It is recommended that you **use CPU session to debug** when GPU is not necessary since Colab only gives 12 hrs of free GPU access at a time. If you use up the GPU resource, you may consider to use Kaggle GPU resource. Thank you and good luck!

# Self-supervised learning: SimCLR

Self-supervised learning

1.   Design an auxiliary task.
2.   Train the base network on the auxiliary task.
3.   Evaluate on the down-stream task: Train a new decoder based on the trained encoder.

Specifically, as one of the most successful self-supervised learning algorithm, SimCLR, a contrastive learning algorithm is what we focus today. Below, we are going to implement SimCLR as an example of self-supervised learning.


<img src="https://camo.githubusercontent.com/35af3432fbe91c56a934b5ee58931b4848ab35043830c9dd6f08fa41e6eadbe7/68747470733a2f2f312e62702e626c6f6773706f742e636f6d2f2d2d764834504b704539596f2f586f3461324259657276492f414141414141414146704d2f766146447750584f79416f6b4143385868383532447a4f67457332324e68625877434c63424741735948512f73313630302f696d616765342e676966" width="650" height="650">

In [1]:
# Config
# Since, we are using jupyter notebook, we use easydict to micic argparse. Feel free to use other format of config
from easydict import EasyDict
import torch.nn as nn
from tqdm import tqdm
import torch

config = {
    'dataset_name': 'stl10',
    'workers': 1,
    'epochs': 70,
    'batch_size': 2048,
    'lr': 0.0003,
    'weight_decay': 1e-4,
    'seed': 4242,
    'fp16_precision': True,
    'out_dim': 128,
    'temperature': 0.5,
    'n_views': 2,
    'device': "cuda" if torch.cuda.is_available() else "cpu",
    # 'device': "mps" if torch.backends.mps.is_available() else "cpu",

}
args = EasyDict(config)

In [2]:
# Seed Setting
import random
import numpy as np

torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)

We are going to use [STL-10 dataset](https://cs.stanford.edu/~acoates/stl10/).

<img src="https://cs.stanford.edu/~acoates/stl10/images.png" width="450" height="450">

Overview

*   10 classes: airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck.
*   Images are **96x96** pixels, color.
*   500 training images (10 pre-defined folds), 800 test images per class.
*   100000 unlabeled images for unsupervised learning. These examples are extracted from a similar but broader distribution of images. For instance, it contains other types of animals (bears, rabbits, etc.) and vehicles (trains, buses, etc.) in addition to the ones in the labeled set.
*   Images were acquired from labeled examples on ImageNet.


## Preparation

Define a ResNet-18 and an additional MLP layer as the model training in the auxiliary task.

In [3]:
import torchvision.models as models

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.basemodel = models.resnet18(pretrained=False, num_classes=args.out_dim)
        self.fc_in_features = self.basemodel.fc.in_features
        self.backup_fc = None
        self.basemodel.fc = nn.Sequential(
          nn.Linear(self.fc_in_features, self.fc_in_features),
          nn.ReLU(),
          self.basemodel.fc,
        )

    def forward(self, x):
        return self.basemodel(x)

    def linear_probe(self):
        self.freeze_basemodel_encoder()
        self.backup_fc = self.basemodel.fc  # Backup the last Linear layer

        self.basemodel.fc = nn.Linear(self.fc_in_features, 10)


    def restore_backbone(self):
        self.basemodel.fc = self.backup_fc
        self.backup_fc = None

    def freeze_basemodel_encoder(self):
        # do not freeze the self.basemodel.fc weights
        for name, param in self.basemodel.named_parameters():
            if 'fc' not in name:
                param.requires_grad = False

# Step 1: Design the auxiliary task.
## construct the dataset

In [4]:
from torchvision import transforms, datasets

class View_sampler(object):
    """This class randomly sample two transforms from the list of transforms for the SimCLR to use. It is used in the SimCLRDataset.get_dataset."""

    def __init__(self, transforms, n_views=2):
        self.transforms = transforms
        self.n_views = n_views

    def __call__(self, x):
        return [self.transforms(x) for i in range(self.n_views)]


class SimCLRDataset:
    def __init__(self, root_folder="./datasets"):
        self.root_folder = root_folder

    @staticmethod
    def transforms_pool():
        data_transforms = transforms.Compose([
            transforms.RandomResizedCrop(96, scale=(0.2, 1.0)),  # Crop randomly the image and resize it to 96x96
            transforms.RandomHorizontalFlip(),  # Flip the image horizontally with a probability of 0.5
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # Randomly change the brightness, contrast, saturation and hue
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),  # Randomly convert image to grayscale
            # transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),  # Apply Gaussian Blur
            transforms.ToTensor(),  # Convert image to tensor
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image
        ])

        return data_transforms

    def get_dataset(self):
        dataset_fn = lambda: datasets.STL10(self.root_folder, split='unlabeled', transform=View_sampler(self.transforms_pool(), args.n_views), download=True)
        return dataset_fn()

## Define dataloader, optimizer and scheduler

What is a scheduler?

A scheduler helps in optimizing the convergence, avoiding local minima, and potentially improving the model's performance on the task at hand. The learning rate is one of the most important hyperparameters for training neural networks, and finding an appropriate learning rate schedule can be crucial for your model's success.

<img src="https://miro.medium.com/v2/resize:fit:4800/format:webp/1*qe6nYlH8zsmUdScyHMhRCQ.png" width="1200" height="450">

Read more here: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

In [5]:
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

model = Net()
dataset = SimCLRDataset()
train_dataset = dataset.get_dataset()

dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)

optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

lr_scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0.00001)



Files already downloaded and verified


# Define the trainer


Automatic Mixed Precision (AMP) is a technique that aims to improve the speed and efficiency of training deep neural networks by leveraging mixed-precision training.

**Introduction to AMP**

AMP allows neural network training to use both single-precision (FP32) and half-precision (FP16) floating point arithmetic simultaneously. The main idea behind AMP is to perform certain operations in FP16 to exploit the faster arithmetic and reduced memory usage of lower-precision computing, while maintaining the critical parts of the computation in FP32 to ensure model accuracy and stability.

**Why We Cannot Always Use FP16**



*   **Numerical Stability**: FP16 has a smaller dynamic range and lower precision compared to FP32. This limitation can lead to numerical instability, such as underflows and overflows, particularly during operations that involve small gradient values or require high numerical precision. This can adversely affect the convergence and accuracy of the trained model.
*   **Selective Precision Requirements**: Certain operations and layers within neural networks are more sensitive to precision than others. For example, weight updates in optimizers might require FP32 to maintain accuracy over time. AMP strategies, therefore, involve selectively applying FP16 to parts of the computation where it can be beneficial without undermining the overall training process.

Below, we present how to include AMP logic in standard torch training procedure.

Before including AMP:
```python
for batch in data_loader:
    # Forward pass
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

    # Backward pass and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

After including AMP:
```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for batch in data_loader:
    inputs, targets = batch[0].cuda(), batch[1].cuda()

    # Forward pass
    with autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

    # Backward pass and optimize
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```


Read more here: https://pytorch.org/docs/stable/amp.html

## Implement loss function in the Trainer
The algorithm of SimCLR Loss function is as follows:
$$
\begin{aligned}
&\text{for all } i \in \{1, \ldots, 2N\} \text{ and } j \in \{1, \ldots, 2N\} \text{ do} \\
&\quad s_{i,j} = \frac{z_i^\top z_j}{\|z_i\|\|z_j\|} \quad \text{# pairwise similarity} \\
&\text{end for} \\
&\text{define } \ell(i,j) \text{ as } \ell(i,j) = -\log \left( \frac{\exp(s_{i,j} / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq j]} \exp(s_{i,k} / \tau)} \right) \\
&\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[\ell(2k-1, 2k) + \ell(2k, 2k-1)\right] \\
&\text{update networks } f \text{ and } g \text{ to minimize } \mathcal{L}
\end{aligned}
$$

Please fill the blanks in the loss function below. Hint: implement mask to avoid including self-self similarity, and postive pairs and negative pairs.

In [6]:
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
# from torch.utils.tensorboard import SummaryWriter  # for logging
import os


class Trainer():
    def __init__(self, *args, **kwargs):
        self.args = kwargs['args']
        self.model = kwargs['model'].to(self.args.device)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

        self.writer = SummaryWriter(log_dir='./logs/simclr')  # Initialize a SummaryWriter to log data for TensorBoard


    def loss(self, features):

        n_samples = features.shape[0] // 2  # because we have 2N samples where N is the batch size
        labels = torch.cat([torch.arange(n_samples) for _ in range(2)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(self.args.device)
        features = F.normalize(features, dim=1)
        
        similarity_matrix = torch.matmul(features, features.T)
        
        # Mask for positive samples
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
        labels.masked_fill_(mask, 0)
        
        # Select the positive pairs (i, i+n_samples)
        positives = similarity_matrix[range(n_samples), range(n_samples, 2*n_samples)]
        negatives = similarity_matrix[~mask].view(n_samples, -1)

        # Log-sum-exp trick for numerical stability
        logits = torch.cat([positives.unsqueeze(1), negatives], dim=1)
        logits /= self.args.temperature

        labels = torch.zeros(n_samples, dtype=torch.long).to(self.args.device)
        loss = F.cross_entropy(logits, labels)
        
        return loss.mean()

    
    def train(self, dataloader):
        # implement GradScaler if AMP
        best_loss = 1e4
        scaler = GradScaler(enabled=self.args.fp16_precision)
        for epoch in range(self.args.epochs):
            for images, _ in tqdm(dataloader):
                images = torch.cat(images, dim=0)
                images = images.to(self.args.device)

                with autocast(enabled=self.args.fp16_precision):
                    features = self.model(images)
                    loss = self.loss(features)

                # features = self.model(images)
                # loss = self.loss(features)
                self.optimizer.zero_grad()

                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()

                # loss.backward()
                # self.optimizer.step()
                self.scheduler.step()

            # warmup for the first 10 epochs
            if epoch >= 10:
                self.scheduler.step()
            if epoch % 10 == 0 and epoch != 0:
                self.save_model(self.model, f"model_{epoch}.pth")
            # save the lowest loss model
            # feel free to implement your own logic to save the best model
            if loss < best_loss:
                best_loss = loss
                self.save_model(self.model, f"best_model.pth")
            print(f"Epoch {epoch}, Loss {loss.item()}")
            # self.writer.add_scalar("Loss/train", loss, epoch)
            self.writer.add_scalar("Loss/Train", loss, epoch)
            avg_epoch_loss = loss / len(dataloader)
            self.writer.add_scalar("AvgLoss/Train", avg_epoch_loss, epoch)
            self.writer.close()
            
        return self.model

    def save_model(self, model, path):
        torch.save(model.state_dict(), path)

# Step 2: Train the base network on the auxiliary task for 70 epoch and save the best model you have for evaluation.

Check if the training loss drops over time and try to capture other possible bug using logging tools. Each epoch should take around 7 minutes. The loss should be expected to around 7.4.

In [7]:
trainer = Trainer(args=args, model=model, optimizer=optimizer, scheduler=lr_scheduler)
trainer.train(dataloader)

100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 0, Loss 7.990354537963867


100%|██████████| 48/48 [03:35<00:00,  4.50s/it]


Epoch 1, Loss 7.854154586791992


100%|██████████| 48/48 [03:37<00:00,  4.54s/it]


Epoch 2, Loss 7.769664764404297


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 3, Loss 7.738748550415039


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 4, Loss 7.71990966796875


100%|██████████| 48/48 [03:36<00:00,  4.52s/it]


Epoch 5, Loss 7.674175262451172


100%|██████████| 48/48 [03:36<00:00,  4.51s/it]


Epoch 6, Loss 7.664894104003906


100%|██████████| 48/48 [03:36<00:00,  4.52s/it]


Epoch 7, Loss 7.641824722290039


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 8, Loss 7.609840393066406


100%|██████████| 48/48 [03:36<00:00,  4.52s/it]


Epoch 9, Loss 7.594123840332031


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 10, Loss 7.585121154785156


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 11, Loss 7.590019226074219


100%|██████████| 48/48 [03:34<00:00,  4.47s/it]


Epoch 12, Loss 7.579301834106445


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 13, Loss 7.55980110168457


100%|██████████| 48/48 [03:38<00:00,  4.55s/it]


Epoch 14, Loss 7.571329116821289


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 15, Loss 7.548988342285156


100%|██████████| 48/48 [03:36<00:00,  4.51s/it]


Epoch 16, Loss 7.536582946777344


100%|██████████| 48/48 [03:36<00:00,  4.51s/it]


Epoch 17, Loss 7.524885177612305


100%|██████████| 48/48 [03:34<00:00,  4.47s/it]


Epoch 18, Loss 7.501638412475586


100%|██████████| 48/48 [03:36<00:00,  4.51s/it]


Epoch 19, Loss 7.513612747192383


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 20, Loss 7.506864547729492


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 21, Loss 7.513397216796875


100%|██████████| 48/48 [03:35<00:00,  4.50s/it]


Epoch 22, Loss 7.51617431640625


100%|██████████| 48/48 [03:36<00:00,  4.52s/it]


Epoch 23, Loss 7.509376525878906


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 24, Loss 7.503721237182617


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 25, Loss 7.487756729125977


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 26, Loss 7.490514755249023


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 27, Loss 7.468572616577148


100%|██████████| 48/48 [03:36<00:00,  4.51s/it]


Epoch 28, Loss 7.452425003051758


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 29, Loss 7.457487106323242


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 30, Loss 7.472375869750977


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 31, Loss 7.4730072021484375


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 32, Loss 7.482873916625977


100%|██████████| 48/48 [03:34<00:00,  4.47s/it]


Epoch 33, Loss 7.460201263427734


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 34, Loss 7.470294952392578


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 35, Loss 7.46534538269043


100%|██████████| 48/48 [03:35<00:00,  4.50s/it]


Epoch 36, Loss 7.451847076416016


100%|██████████| 48/48 [03:37<00:00,  4.52s/it]


Epoch 37, Loss 7.423942565917969


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 38, Loss 7.442647933959961


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 39, Loss 7.4392852783203125


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 40, Loss 7.442821502685547


100%|██████████| 48/48 [03:35<00:00,  4.50s/it]


Epoch 41, Loss 7.44810676574707


100%|██████████| 48/48 [03:35<00:00,  4.50s/it]


Epoch 42, Loss 7.448490142822266


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 43, Loss 7.433841705322266


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 44, Loss 7.4438323974609375


100%|██████████| 48/48 [03:35<00:00,  4.50s/it]


Epoch 45, Loss 7.4406585693359375


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 46, Loss 7.4363250732421875


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 47, Loss 7.433971405029297


100%|██████████| 48/48 [03:36<00:00,  4.52s/it]


Epoch 48, Loss 7.414365768432617


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 49, Loss 7.415159225463867


100%|██████████| 48/48 [03:36<00:00,  4.50s/it]


Epoch 50, Loss 7.4303131103515625


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 51, Loss 7.4246826171875


100%|██████████| 48/48 [03:36<00:00,  4.51s/it]


Epoch 52, Loss 7.427026748657227


100%|██████████| 48/48 [03:36<00:00,  4.52s/it]


Epoch 53, Loss 7.419414520263672


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 54, Loss 7.4125213623046875


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 55, Loss 7.417049407958984


100%|██████████| 48/48 [03:36<00:00,  4.51s/it]


Epoch 56, Loss 7.4122161865234375


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 57, Loss 7.411289215087891


100%|██████████| 48/48 [03:34<00:00,  4.48s/it]


Epoch 58, Loss 7.392814636230469


100%|██████████| 48/48 [03:36<00:00,  4.51s/it]


Epoch 59, Loss 7.404624938964844


100%|██████████| 48/48 [03:36<00:00,  4.52s/it]


Epoch 60, Loss 7.403038024902344


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 61, Loss 7.412336349487305


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 62, Loss 7.40144157409668


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 63, Loss 7.405712127685547


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 64, Loss 7.404083251953125


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 65, Loss 7.407312393188477


100%|██████████| 48/48 [03:35<00:00,  4.49s/it]


Epoch 66, Loss 7.392786026000977


100%|██████████| 48/48 [03:36<00:00,  4.52s/it]


Epoch 67, Loss 7.377157211303711


100%|██████████| 48/48 [03:35<00:00,  4.48s/it]


Epoch 68, Loss 7.379179000854492


100%|██████████| 48/48 [03:34<00:00,  4.48s/it]

Epoch 69, Loss 7.378019332885742





Net(
  (basemodel): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runni

# Step 3: Evaluate on the down-stream task: Train a new MLP decoder based on the trained encoder.

This finetune process should be way faster than the previous one. The expected Top-1 accuracy should be around 57% and the Top-5 accuracy should be around 97%. Getting this results is normal because linear prob is just a projection layer usually recognized as not having representation ability. To achieve the performance mentioned in the paper, we need larger dataset and more powerful GPU and longer time (about 1000 epoches during the pretraining stage).

In [9]:
class linear_prob_Trainer:
    def __init__(self, *args, **kwargs):
        self.args = kwargs["args"]
        self.model = kwargs["model"].to(self.args.device)
        self.optimizer = kwargs["optimizer"]
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
        self.train_dataset = datasets.STL10(
            "./data", split="train", download=True, transform=transforms.ToTensor()
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            num_workers=1,
            drop_last=False,
        )

        self.test_dataset = datasets.STL10(
            "./data", split="test", download=True, transform=transforms.ToTensor()
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.args.batch_size,
            num_workers=1,
            drop_last=False,
        )

        self.writer = SummaryWriter(log_dir='./logs/linear_probe')

    def accuracy(self, output, target, topk=(1,)):
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    def train(self, dataloader):
        for epoch in range(100):
            top1_train_accuracy = 0
            for images, labels in tqdm(dataloader):
                images, labels = images.to(self.args.device), labels.to(
                    self.args.device
                )
                logits = self.model(images)
                loss = self.criterion(logits, labels)
                top1 = self.accuracy(logits, labels, topk=(1,))
                top1_train_accuracy += top1[0]

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                print(
                    f"Epoch: {epoch}, Loss: {loss.item()}",
                    "Top1 Train Accuracy: ",
                    top1_train_accuracy.item() / len(dataloader),
                )
            
            self.writer.add_scalar("Loss/Train", loss, epoch)
            avg_epoch_loss = loss / len(dataloader)
            self.writer.add_scalar("AvgLoss/Train", avg_epoch_loss, epoch)
            self.writer.add_scalar("AccuracyTop1/Train", top1_train_accuracy, epoch)
            avg_train_accuracy = top1_train_accuracy / len(dataloader.dataset)
            self.writer.add_scalar("AvgAccuracyTop1/Train", avg_train_accuracy, epoch)

        self.writer.close()
        return self.model

    def test(self, dataloader):
        with torch.no_grad():
            model.eval()
            top1_test_accuracy = 0
            top5_test_accuracy = 0
            for images, labels in tqdm(dataloader):
                images, labels = images.to(self.args.device), labels.to(
                    self.args.device
                )
                logits = self.model(images)
                top1 = self.accuracy(logits, labels, topk=(1,))
                top1_test_accuracy += top1[0]
                top5 = self.accuracy(logits, labels, topk=(5,))
                top5_test_accuracy += top5[0]
            print("Top1 Test Accuracy: ", top1_test_accuracy.item() / len(dataloader))
            print("Top5 Test Accuracy: ", top5_test_accuracy.item() / len(dataloader))

            self.writer.add_scalar("AccuracyTop1/Test", top1_test_accuracy, 0)
            self.writer.add_scalar("AccuracyTop5/Test", top5_test_accuracy, 0)
            avg_top1_accuracy = top1_test_accuracy / len(dataloader.dataset)
            avg_top5_accuracy = top5_test_accuracy / len(dataloader.dataset)
            self.writer.add_scalar("AvgAccuracyTop1/Test", avg_top1_accuracy, 0)
            self.writer.add_scalar("AvgAccuracyTop5/Test", avg_top5_accuracy, 0)
            self.writer.close()

            return

model.linear_probe()
linear_probe_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
linear_prob_trainer = linear_prob_Trainer(args=args, model=model, optimizer=linear_probe_optimizer)
model = linear_prob_trainer.train(linear_prob_trainer.train_loader)
linear_prob_trainer.test(linear_prob_trainer.test_loader)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [02:22<00:00, 18553394.38it/s]


Extracting ./data/stl10_binary.tar.gz to ./data
Files already downloaded and verified


 33%|███▎      | 1/3 [00:00<00:01,  1.06it/s]

Epoch: 0, Loss: 2.5003902912139893 Top1 Train Accuracy:  3.0436197916666665


100%|██████████| 3/3 [00:01<00:00,  1.75it/s]


Epoch: 0, Loss: 2.4712836742401123 Top1 Train Accuracy:  5.777994791666667
Epoch: 0, Loss: 2.4010021686553955 Top1 Train Accuracy:  8.469735463460287


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 1, Loss: 2.3463659286499023 Top1 Train Accuracy:  3.2063802083333335


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 1, Loss: 2.3157758712768555 Top1 Train Accuracy:  6.673177083333333
Epoch: 1, Loss: 2.255971908569336 Top1 Train Accuracy:  10.72922388712565


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 2, Loss: 2.212472677230835 Top1 Train Accuracy:  4.964192708333333


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 2, Loss: 2.1794545650482178 Top1 Train Accuracy:  10.904947916666666
Epoch: 2, Loss: 2.1281771659851074 Top1 Train Accuracy:  17.615862528483074


 33%|███▎      | 1/3 [00:00<00:01,  1.15it/s]

Epoch: 3, Loss: 2.093695640563965 Top1 Train Accuracy:  8.154296875


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 3, Loss: 2.0583786964416504 Top1 Train Accuracy:  17.3828125
Epoch: 3, Loss: 2.0140135288238525 Top1 Train Accuracy:  27.781041463216145


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 4, Loss: 1.9868627786636353 Top1 Train Accuracy:  11.002604166666666


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 4, Loss: 1.9498728513717651 Top1 Train Accuracy:  23.209635416666668
Epoch: 4, Loss: 1.9111865758895874 Top1 Train Accuracy:  36.44709777832031


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 5, Loss: 1.8902150392532349 Top1 Train Accuracy:  13.118489583333334


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 5, Loss: 1.8524482250213623 Top1 Train Accuracy:  27.083333333333332
Epoch: 5, Loss: 1.8185011148452759 Top1 Train Accuracy:  41.72197469075521


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 6, Loss: 1.8029725551605225 Top1 Train Accuracy:  14.501953125


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 6, Loss: 1.7653416395187378 Top1 Train Accuracy:  29.801432291666668
Epoch: 6, Loss: 1.735379934310913 Top1 Train Accuracy:  45.47252400716146


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 7, Loss: 1.7248213291168213 Top1 Train Accuracy:  15.445963541666666


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 7, Loss: 1.6880297660827637 Top1 Train Accuracy:  31.754557291666668
Epoch: 7, Loss: 1.661399483680725 Top1 Train Accuracy:  48.53184509277344


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 8, Loss: 1.6554689407348633 Top1 Train Accuracy:  16.520182291666668


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 8, Loss: 1.6198577880859375 Top1 Train Accuracy:  33.756510416666664
Epoch: 8, Loss: 1.5959713459014893 Top1 Train Accuracy:  51.086893717447914


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 9, Loss: 1.5943751335144043 Top1 Train Accuracy:  17.333984375


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 9, Loss: 1.5598915815353394 Top1 Train Accuracy:  34.879557291666664
Epoch: 9, Loss: 1.5382566452026367 Top1 Train Accuracy:  52.689290364583336


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 10, Loss: 1.54073166847229 Top1 Train Accuracy:  17.561848958333332


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 10, Loss: 1.5070290565490723 Top1 Train Accuracy:  35.693359375
Epoch: 10, Loss: 1.4873040914535522 Top1 Train Accuracy:  53.94557189941406


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 11, Loss: 1.4936262369155884 Top1 Train Accuracy:  17.838541666666668


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 11, Loss: 1.460228443145752 Top1 Train Accuracy:  36.279296875
Epoch: 11, Loss: 1.4422461986541748 Top1 Train Accuracy:  54.97398885091146


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 12, Loss: 1.4522136449813843 Top1 Train Accuracy:  18.26171875


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 12, Loss: 1.4186627864837646 Top1 Train Accuracy:  37.125651041666664
Epoch: 12, Loss: 1.4023782014846802 Top1 Train Accuracy:  56.41031392415365


 33%|███▎      | 1/3 [00:00<00:01,  1.15it/s]

Epoch: 13, Loss: 1.4157682657241821 Top1 Train Accuracy:  18.473307291666668


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 13, Loss: 1.3817014694213867 Top1 Train Accuracy:  37.565104166666664
Epoch: 13, Loss: 1.3671082258224487 Top1 Train Accuracy:  57.107879638671875


 33%|███▎      | 1/3 [00:00<00:01,  1.15it/s]

Epoch: 14, Loss: 1.3836380243301392 Top1 Train Accuracy:  18.603515625


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 14, Loss: 1.3488209247589111 Top1 Train Accuracy:  37.825520833333336
Epoch: 14, Loss: 1.3358781337738037 Top1 Train Accuracy:  57.62640380859375


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 15, Loss: 1.3552123308181763 Top1 Train Accuracy:  18.84765625


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 15, Loss: 1.3195406198501587 Top1 Train Accuracy:  38.346354166666664
Epoch: 15, Loss: 1.3081468343734741 Top1 Train Accuracy:  58.220987955729164


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 16, Loss: 1.3299386501312256 Top1 Train Accuracy:  18.84765625


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 16, Loss: 1.293418288230896 Top1 Train Accuracy:  38.525390625
Epoch: 16, Loss: 1.2834159135818481 Top1 Train Accuracy:  58.69500732421875


 33%|███▎      | 1/3 [00:00<00:01,  1.12it/s]

Epoch: 17, Loss: 1.3073550462722778 Top1 Train Accuracy:  18.961588541666668


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 17, Loss: 1.2700592279434204 Top1 Train Accuracy:  38.704427083333336
Epoch: 17, Loss: 1.2612534761428833 Top1 Train Accuracy:  59.05841064453125


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 18, Loss: 1.2870965003967285 Top1 Train Accuracy:  19.108072916666668


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 18, Loss: 1.2491127252578735 Top1 Train Accuracy:  38.932291666666664
Epoch: 18, Loss: 1.241295576095581 Top1 Train Accuracy:  59.36002095540365


 33%|███▎      | 1/3 [00:00<00:01,  1.15it/s]

Epoch: 19, Loss: 1.2688714265823364 Top1 Train Accuracy:  19.352213541666668


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 19, Loss: 1.230261206626892 Top1 Train Accuracy:  39.420572916666664
Epoch: 19, Loss: 1.223235011100769 Top1 Train Accuracy:  59.885172526041664


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 20, Loss: 1.2524350881576538 Top1 Train Accuracy:  19.466145833333332


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 20, Loss: 1.213214635848999 Top1 Train Accuracy:  39.697265625
Epoch: 20, Loss: 1.2068142890930176 Top1 Train Accuracy:  60.19874064127604


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 21, Loss: 1.2375718355178833 Top1 Train Accuracy:  19.563802083333332


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 21, Loss: 1.1977179050445557 Top1 Train Accuracy:  39.908854166666664
Epoch: 21, Loss: 1.1918219327926636 Top1 Train Accuracy:  60.41032918294271


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 22, Loss: 1.2240897417068481 Top1 Train Accuracy:  19.694010416666668


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 22, Loss: 1.183556318283081 Top1 Train Accuracy:  40.13671875
Epoch: 22, Loss: 1.178086280822754 Top1 Train Accuracy:  60.785685221354164


 33%|███▎      | 1/3 [00:00<00:01,  1.12it/s]

Epoch: 23, Loss: 1.2118158340454102 Top1 Train Accuracy:  19.742838541666668


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 23, Loss: 1.170555591583252 Top1 Train Accuracy:  40.234375
Epoch: 23, Loss: 1.165465235710144 Top1 Train Accuracy:  61.03083292643229


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 24, Loss: 1.2005937099456787 Top1 Train Accuracy:  19.807942708333332


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 24, Loss: 1.1585736274719238 Top1 Train Accuracy:  40.364583333333336
Epoch: 24, Loss: 1.1538331508636475 Top1 Train Accuracy:  61.197916666666664


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 25, Loss: 1.1902847290039062 Top1 Train Accuracy:  19.873046875


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 25, Loss: 1.1474947929382324 Top1 Train Accuracy:  40.462239583333336
Epoch: 25, Loss: 1.1430789232254028 Top1 Train Accuracy:  61.295572916666664


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 26, Loss: 1.1807698011398315 Top1 Train Accuracy:  19.954427083333332


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 26, Loss: 1.1372222900390625 Top1 Train Accuracy:  40.641276041666664
Epoch: 26, Loss: 1.1331007480621338 Top1 Train Accuracy:  61.622100830078125


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 27, Loss: 1.171951174736023 Top1 Train Accuracy:  19.986979166666668


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 27, Loss: 1.1276754140853882 Top1 Train Accuracy:  40.738932291666664
Epoch: 27, Loss: 1.123807668685913 Top1 Train Accuracy:  61.756632486979164


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 28, Loss: 1.16374933719635 Top1 Train Accuracy:  19.986979166666668


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 28, Loss: 1.1187834739685059 Top1 Train Accuracy:  40.8203125
Epoch: 28, Loss: 1.1151198148727417 Top1 Train Accuracy:  61.985504150390625


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 29, Loss: 1.1560989618301392 Top1 Train Accuracy:  20.003255208333332


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 29, Loss: 1.1104822158813477 Top1 Train Accuracy:  40.885416666666664
Epoch: 29, Loss: 1.106967568397522 Top1 Train Accuracy:  62.01373291015625


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 30, Loss: 1.1489448547363281 Top1 Train Accuracy:  20.100911458333332


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 30, Loss: 1.1027121543884277 Top1 Train Accuracy:  40.983072916666664
Epoch: 30, Loss: 1.0992928743362427 Top1 Train Accuracy:  62.07451883951823


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 31, Loss: 1.1422377824783325 Top1 Train Accuracy:  20.068359375


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 31, Loss: 1.0954208374023438 Top1 Train Accuracy:  40.966796875
Epoch: 31, Loss: 1.092048168182373 Top1 Train Accuracy:  62.05824279785156


 33%|███▎      | 1/3 [00:00<00:01,  1.15it/s]

Epoch: 32, Loss: 1.1359343528747559 Top1 Train Accuracy:  20.084635416666668


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 32, Loss: 1.0885603427886963 Top1 Train Accuracy:  41.048177083333336
Epoch: 32, Loss: 1.0851949453353882 Top1 Train Accuracy:  62.176493326822914


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 33, Loss: 1.1299948692321777 Top1 Train Accuracy:  20.166015625


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 33, Loss: 1.0820887088775635 Top1 Train Accuracy:  41.178385416666664
Epoch: 33, Loss: 1.0786999464035034 Top1 Train Accuracy:  62.41732279459635


 33%|███▎      | 1/3 [00:00<00:01,  1.15it/s]

Epoch: 34, Loss: 1.1243839263916016 Top1 Train Accuracy:  20.21484375


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 34, Loss: 1.075968861579895 Top1 Train Accuracy:  41.30859375
Epoch: 34, Loss: 1.0725349187850952 Top1 Train Accuracy:  62.695027669270836


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 35, Loss: 1.1190693378448486 Top1 Train Accuracy:  20.263671875


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 35, Loss: 1.07016921043396 Top1 Train Accuracy:  41.30859375
Epoch: 35, Loss: 1.066672682762146 Top1 Train Accuracy:  62.695027669270836


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 36, Loss: 1.1140241622924805 Top1 Train Accuracy:  20.3125


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 36, Loss: 1.0646615028381348 Top1 Train Accuracy:  41.389973958333336
Epoch: 36, Loss: 1.0610886812210083 Top1 Train Accuracy:  62.96076965332031


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 37, Loss: 1.1092251539230347 Top1 Train Accuracy:  20.3125


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 37, Loss: 1.0594217777252197 Top1 Train Accuracy:  41.389973958333336
Epoch: 37, Loss: 1.055759310722351 Top1 Train Accuracy:  62.99764506022135


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 38, Loss: 1.1046520471572876 Top1 Train Accuracy:  20.3125


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 38, Loss: 1.054429531097412 Top1 Train Accuracy:  41.438802083333336
Epoch: 38, Loss: 1.0506632328033447 Top1 Train Accuracy:  63.00959777832031


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 39, Loss: 1.1002863645553589 Top1 Train Accuracy:  20.3125


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 39, Loss: 1.0496665239334106 Top1 Train Accuracy:  41.520182291666664
Epoch: 39, Loss: 1.0457810163497925 Top1 Train Accuracy:  63.16472371419271


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 40, Loss: 1.0961123704910278 Top1 Train Accuracy:  20.328776041666668


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 40, Loss: 1.0451161861419678 Top1 Train Accuracy:  41.666666666666664
Epoch: 40, Loss: 1.0410958528518677 Top1 Train Accuracy:  63.34808349609375


 33%|███▎      | 1/3 [00:00<00:01,  1.15it/s]

Epoch: 41, Loss: 1.092114806175232 Top1 Train Accuracy:  20.3125


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 41, Loss: 1.040763258934021 Top1 Train Accuracy:  41.715494791666664
Epoch: 41, Loss: 1.0365928411483765 Top1 Train Accuracy:  63.47065734863281


 33%|███▎      | 1/3 [00:00<00:01,  1.17it/s]

Epoch: 42, Loss: 1.0882798433303833 Top1 Train Accuracy:  20.377604166666668


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 42, Loss: 1.0365941524505615 Top1 Train Accuracy:  41.813151041666664
Epoch: 42, Loss: 1.0322593450546265 Top1 Train Accuracy:  63.605183919270836


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 43, Loss: 1.0845953226089478 Top1 Train Accuracy:  20.458984375


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 43, Loss: 1.0325957536697388 Top1 Train Accuracy:  41.910807291666664
Epoch: 43, Loss: 1.0280845165252686 Top1 Train Accuracy:  63.702840169270836


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 44, Loss: 1.0810496807098389 Top1 Train Accuracy:  20.475260416666668


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 44, Loss: 1.0287559032440186 Top1 Train Accuracy:  41.927083333333336
Epoch: 44, Loss: 1.0240578651428223 Top1 Train Accuracy:  63.68224589029948


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 45, Loss: 1.077634334564209 Top1 Train Accuracy:  20.491536458333332


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 45, Loss: 1.025063395500183 Top1 Train Accuracy:  41.89453125
Epoch: 45, Loss: 1.0201703310012817 Top1 Train Accuracy:  63.612823486328125


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 46, Loss: 1.074339509010315 Top1 Train Accuracy:  20.524088541666668


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 46, Loss: 1.0215080976486206 Top1 Train Accuracy:  41.959635416666664
Epoch: 46, Loss: 1.0164132118225098 Top1 Train Accuracy:  63.67792765299479


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 47, Loss: 1.0711581707000732 Top1 Train Accuracy:  20.589192708333332


100%|██████████| 3/3 [00:01<00:00,  1.84it/s]


Epoch: 47, Loss: 1.0180810689926147 Top1 Train Accuracy:  42.008463541666664
Epoch: 47, Loss: 1.0127785205841064 Top1 Train Accuracy:  63.837371826171875


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 48, Loss: 1.068082571029663 Top1 Train Accuracy:  20.638020833333332


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 48, Loss: 1.0147737264633179 Top1 Train Accuracy:  42.122395833333336
Epoch: 48, Loss: 1.0092582702636719 Top1 Train Accuracy:  63.95130411783854


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 49, Loss: 1.0651063919067383 Top1 Train Accuracy:  20.654296875


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 49, Loss: 1.0115797519683838 Top1 Train Accuracy:  42.220052083333336
Epoch: 49, Loss: 1.0058456659317017 Top1 Train Accuracy:  64.08583577473958


 33%|███▎      | 1/3 [00:00<00:01,  1.15it/s]

Epoch: 50, Loss: 1.0622235536575317 Top1 Train Accuracy:  20.670572916666668


100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Epoch: 50, Loss: 1.0084919929504395 Top1 Train Accuracy:  42.333984375
Epoch: 50, Loss: 1.0025343894958496 Top1 Train Accuracy:  64.19976806640625


 33%|███▎      | 1/3 [00:00<00:01,  1.16it/s]

Epoch: 51, Loss: 1.0594279766082764 Top1 Train Accuracy:  20.735677083333332


100%|██████████| 3/3 [00:01<00:00,  1.86it/s]


Epoch: 51, Loss: 1.0055047273635864 Top1 Train Accuracy:  42.447916666666664
Epoch: 51, Loss: 0.999318540096283 Top1 Train Accuracy:  64.31370035807292


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 52, Loss: 1.0567150115966797 Top1 Train Accuracy:  20.751953125


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 52, Loss: 1.0026122331619263 Top1 Train Accuracy:  42.447916666666664
Epoch: 52, Loss: 0.996192991733551 Top1 Train Accuracy:  64.35057067871094


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 53, Loss: 1.0540794134140015 Top1 Train Accuracy:  20.817057291666668


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 53, Loss: 0.9998091459274292 Top1 Train Accuracy:  42.561848958333336
Epoch: 53, Loss: 0.9931525588035583 Top1 Train Accuracy:  64.50137837727864


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 54, Loss: 1.051517367362976 Top1 Train Accuracy:  20.833333333333332


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 54, Loss: 0.997090756893158 Top1 Train Accuracy:  42.578125
Epoch: 54, Loss: 0.9901930689811707 Top1 Train Accuracy:  64.48077901204427


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 55, Loss: 1.0490241050720215 Top1 Train Accuracy:  20.849609375


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 55, Loss: 0.9944526553153992 Top1 Train Accuracy:  42.594401041666664
Epoch: 55, Loss: 0.987309992313385 Top1 Train Accuracy:  64.49705505371094


 33%|███▎      | 1/3 [00:00<00:01,  1.12it/s]

Epoch: 56, Loss: 1.0465964078903198 Top1 Train Accuracy:  20.882161458333332


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 56, Loss: 0.9918901324272156 Top1 Train Accuracy:  42.643229166666664
Epoch: 56, Loss: 0.984499990940094 Top1 Train Accuracy:  64.50901285807292


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 57, Loss: 1.0442310571670532 Top1 Train Accuracy:  20.914713541666668


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 57, Loss: 0.9893994927406311 Top1 Train Accuracy:  42.692057291666664
Epoch: 57, Loss: 0.9817589521408081 Top1 Train Accuracy:  64.59471130371094


 33%|███▎      | 1/3 [00:00<00:01,  1.10it/s]

Epoch: 58, Loss: 1.0419245958328247 Top1 Train Accuracy:  20.914713541666668


100%|██████████| 3/3 [00:01<00:00,  1.79it/s]


Epoch: 58, Loss: 0.9869771599769592 Top1 Train Accuracy:  42.740885416666664
Epoch: 58, Loss: 0.9790838360786438 Top1 Train Accuracy:  64.64353942871094


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 59, Loss: 1.0396744012832642 Top1 Train Accuracy:  20.8984375


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 59, Loss: 0.9846196174621582 Top1 Train Accuracy:  42.740885416666664
Epoch: 59, Loss: 0.9764712452888489 Top1 Train Accuracy:  64.68041483561198


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 60, Loss: 1.0374773740768433 Top1 Train Accuracy:  20.914713541666668


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 60, Loss: 0.9823238253593445 Top1 Train Accuracy:  42.757161458333336
Epoch: 60, Loss: 0.9739185571670532 Top1 Train Accuracy:  64.80731201171875


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 61, Loss: 1.0353314876556396 Top1 Train Accuracy:  20.930989583333332


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 61, Loss: 0.980086624622345 Top1 Train Accuracy:  42.805989583333336
Epoch: 61, Loss: 0.9714226722717285 Top1 Train Accuracy:  64.8192647298177


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 62, Loss: 1.0332340002059937 Top1 Train Accuracy:  20.947265625


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 62, Loss: 0.9779057502746582 Top1 Train Accuracy:  42.887369791666664
Epoch: 62, Loss: 0.9689812064170837 Top1 Train Accuracy:  64.90064493815105


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 63, Loss: 1.031182885169983 Top1 Train Accuracy:  20.930989583333332


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 63, Loss: 0.975778341293335 Top1 Train Accuracy:  42.919921875
Epoch: 63, Loss: 0.966591477394104 Top1 Train Accuracy:  64.97007242838542


 33%|███▎      | 1/3 [00:00<00:01,  1.12it/s]

Epoch: 64, Loss: 1.029175877571106 Top1 Train Accuracy:  20.930989583333332


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 64, Loss: 0.9737019538879395 Top1 Train Accuracy:  42.96875
Epoch: 64, Loss: 0.9642518162727356 Top1 Train Accuracy:  65.09264119466145


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 65, Loss: 1.0272114276885986 Top1 Train Accuracy:  20.947265625


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 65, Loss: 0.9716746211051941 Top1 Train Accuracy:  43.033854166666664
Epoch: 65, Loss: 0.961959719657898 Top1 Train Accuracy:  65.15774536132812


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 66, Loss: 1.0252870321273804 Top1 Train Accuracy:  20.963541666666668


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 66, Loss: 0.9696940779685974 Top1 Train Accuracy:  43.050130208333336
Epoch: 66, Loss: 0.9597131013870239 Top1 Train Accuracy:  65.13715108235677


 33%|███▎      | 1/3 [00:00<00:01,  1.12it/s]

Epoch: 67, Loss: 1.0234017372131348 Top1 Train Accuracy:  20.963541666666668


100%|██████████| 3/3 [00:01<00:00,  1.77it/s]


Epoch: 67, Loss: 0.9677582383155823 Top1 Train Accuracy:  43.082682291666664
Epoch: 67, Loss: 0.9575104117393494 Top1 Train Accuracy:  65.20657348632812


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 68, Loss: 1.02155339717865 Top1 Train Accuracy:  21.077473958333332


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 68, Loss: 0.9658653736114502 Top1 Train Accuracy:  43.229166666666664
Epoch: 68, Loss: 0.9553496241569519 Top1 Train Accuracy:  65.35305786132812


 33%|███▎      | 1/3 [00:00<00:01,  1.12it/s]

Epoch: 69, Loss: 1.0197412967681885 Top1 Train Accuracy:  21.09375


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 69, Loss: 0.9640135765075684 Top1 Train Accuracy:  43.26171875
Epoch: 69, Loss: 0.9532294869422913 Top1 Train Accuracy:  65.38560994466145


 33%|███▎      | 1/3 [00:00<00:01,  1.12it/s]

Epoch: 70, Loss: 1.017963171005249 Top1 Train Accuracy:  21.09375


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 70, Loss: 0.9622011184692383 Top1 Train Accuracy:  43.294270833333336
Epoch: 70, Loss: 0.9511480927467346 Top1 Train Accuracy:  65.4181620279948


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 71, Loss: 1.0162180662155151 Top1 Train Accuracy:  21.09375


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 71, Loss: 0.9604265689849854 Top1 Train Accuracy:  43.277994791666664
Epoch: 71, Loss: 0.9491040110588074 Top1 Train Accuracy:  65.43876139322917


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 72, Loss: 1.0145047903060913 Top1 Train Accuracy:  21.175130208333332


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 72, Loss: 0.9586884379386902 Top1 Train Accuracy:  43.424479166666664
Epoch: 72, Loss: 0.9470958709716797 Top1 Train Accuracy:  65.69586690266927


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 73, Loss: 1.012822151184082 Top1 Train Accuracy:  21.158854166666668


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 73, Loss: 0.9569849967956543 Top1 Train Accuracy:  43.408203125
Epoch: 73, Loss: 0.9451226592063904 Top1 Train Accuracy:  65.64271545410156


 33%|███▎      | 1/3 [00:00<00:01,  1.07it/s]

Epoch: 74, Loss: 1.0111688375473022 Top1 Train Accuracy:  21.19140625


100%|██████████| 3/3 [00:01<00:00,  1.76it/s]


Epoch: 74, Loss: 0.9553154110908508 Top1 Train Accuracy:  43.424479166666664
Epoch: 74, Loss: 0.9431825280189514 Top1 Train Accuracy:  65.65899149576823


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 75, Loss: 1.009544014930725 Top1 Train Accuracy:  21.223958333333332


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 75, Loss: 0.9536781907081604 Top1 Train Accuracy:  43.489583333333336
Epoch: 75, Loss: 0.9412745833396912 Top1 Train Accuracy:  65.72409566243489


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 76, Loss: 1.0079466104507446 Top1 Train Accuracy:  21.207682291666668


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 76, Loss: 0.9520721435546875 Top1 Train Accuracy:  43.505859375
Epoch: 76, Loss: 0.9393981099128723 Top1 Train Accuracy:  65.77724711100261


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 77, Loss: 1.0063755512237549 Top1 Train Accuracy:  21.207682291666668


100%|██████████| 3/3 [00:01<00:00,  1.78it/s]


Epoch: 77, Loss: 0.9504960775375366 Top1 Train Accuracy:  43.505859375
Epoch: 77, Loss: 0.9375514984130859 Top1 Train Accuracy:  65.77724711100261


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 78, Loss: 1.004830241203308 Top1 Train Accuracy:  21.207682291666668


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 78, Loss: 0.9489490985870361 Top1 Train Accuracy:  43.522135416666664
Epoch: 78, Loss: 0.93573397397995 Top1 Train Accuracy:  65.8303934733073


 33%|███▎      | 1/3 [00:00<00:01,  1.10it/s]

Epoch: 79, Loss: 1.0033096075057983 Top1 Train Accuracy:  21.240234375


100%|██████████| 3/3 [00:01<00:00,  1.79it/s]


Epoch: 79, Loss: 0.947429895401001 Top1 Train Accuracy:  43.603515625
Epoch: 79, Loss: 0.9339444041252136 Top1 Train Accuracy:  65.91177368164062


 33%|███▎      | 1/3 [00:00<00:01,  1.09it/s]

Epoch: 80, Loss: 1.0018126964569092 Top1 Train Accuracy:  21.256510416666668


100%|██████████| 3/3 [00:01<00:00,  1.77it/s]


Epoch: 80, Loss: 0.9459376335144043 Top1 Train Accuracy:  43.636067708333336
Epoch: 80, Loss: 0.9321819543838501 Top1 Train Accuracy:  66.01807657877605


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 81, Loss: 1.0003389120101929 Top1 Train Accuracy:  21.272786458333332


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 81, Loss: 0.9444714784622192 Top1 Train Accuracy:  43.65234375
Epoch: 81, Loss: 0.9304460287094116 Top1 Train Accuracy:  66.0343526204427


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 82, Loss: 0.998887836933136 Top1 Train Accuracy:  21.272786458333332


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 82, Loss: 0.9430304169654846 Top1 Train Accuracy:  43.684895833333336
Epoch: 82, Loss: 0.9287354946136475 Top1 Train Accuracy:  66.06690470377605


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 83, Loss: 0.9974579215049744 Top1 Train Accuracy:  21.2890625


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 83, Loss: 0.9416136741638184 Top1 Train Accuracy:  43.684895833333336
Epoch: 83, Loss: 0.9270496368408203 Top1 Train Accuracy:  66.10377502441406


 33%|███▎      | 1/3 [00:00<00:01,  1.07it/s]

Epoch: 84, Loss: 0.9960492849349976 Top1 Train Accuracy:  21.2890625


100%|██████████| 3/3 [00:01<00:00,  1.74it/s]


Epoch: 84, Loss: 0.9402204155921936 Top1 Train Accuracy:  43.701171875
Epoch: 84, Loss: 0.9253876805305481 Top1 Train Accuracy:  66.12005106608073


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 85, Loss: 0.9946610927581787 Top1 Train Accuracy:  21.337890625


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 85, Loss: 0.9388498067855835 Top1 Train Accuracy:  43.782552083333336
Epoch: 85, Loss: 0.9237488508224487 Top1 Train Accuracy:  66.20143127441406


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 86, Loss: 0.9932925701141357 Top1 Train Accuracy:  21.354166666666668


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 86, Loss: 0.9375013709068298 Top1 Train Accuracy:  43.782552083333336
Epoch: 86, Loss: 0.9221327900886536 Top1 Train Accuracy:  66.23830159505208


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 87, Loss: 0.9919431209564209 Top1 Train Accuracy:  21.354166666666668


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 87, Loss: 0.936174213886261 Top1 Train Accuracy:  43.782552083333336
Epoch: 87, Loss: 0.9205381870269775 Top1 Train Accuracy:  66.27517700195312


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 88, Loss: 0.9906123280525208 Top1 Train Accuracy:  21.337890625


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 88, Loss: 0.9348675012588501 Top1 Train Accuracy:  43.815104166666664
Epoch: 88, Loss: 0.9189651608467102 Top1 Train Accuracy:  66.38147989908855


 33%|███▎      | 1/3 [00:00<00:01,  1.11it/s]

Epoch: 89, Loss: 0.9892996549606323 Top1 Train Accuracy:  21.354166666666668


100%|██████████| 3/3 [00:01<00:00,  1.80it/s]


Epoch: 89, Loss: 0.933580756187439 Top1 Train Accuracy:  43.831380208333336
Epoch: 89, Loss: 0.9174124598503113 Top1 Train Accuracy:  66.3977559407552


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 90, Loss: 0.9880045056343079 Top1 Train Accuracy:  21.354166666666668


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 90, Loss: 0.9323134422302246 Top1 Train Accuracy:  43.831380208333336
Epoch: 90, Loss: 0.9158798456192017 Top1 Train Accuracy:  66.47149658203125


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 91, Loss: 0.9867264628410339 Top1 Train Accuracy:  21.354166666666668


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 91, Loss: 0.9310649633407593 Top1 Train Accuracy:  43.880208333333336
Epoch: 91, Loss: 0.9143669605255127 Top1 Train Accuracy:  66.52032470703125


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 92, Loss: 0.985464870929718 Top1 Train Accuracy:  21.354166666666668


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 92, Loss: 0.9298346042633057 Top1 Train Accuracy:  43.929036458333336
Epoch: 92, Loss: 0.912872850894928 Top1 Train Accuracy:  66.53228251139323


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 93, Loss: 0.984219491481781 Top1 Train Accuracy:  21.38671875


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 93, Loss: 0.9286220073699951 Top1 Train Accuracy:  43.977864583333336
Epoch: 93, Loss: 0.911397397518158 Top1 Train Accuracy:  66.58111063639323


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 94, Loss: 0.9829896688461304 Top1 Train Accuracy:  21.38671875


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 94, Loss: 0.9274263978004456 Top1 Train Accuracy:  43.977864583333336
Epoch: 94, Loss: 0.909939706325531 Top1 Train Accuracy:  66.58111063639323


 33%|███▎      | 1/3 [00:00<00:01,  1.11it/s]

Epoch: 95, Loss: 0.9817752838134766 Top1 Train Accuracy:  21.38671875


100%|██████████| 3/3 [00:01<00:00,  1.79it/s]


Epoch: 95, Loss: 0.9262475967407227 Top1 Train Accuracy:  43.994140625
Epoch: 95, Loss: 0.9084993600845337 Top1 Train Accuracy:  66.63425699869792


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 96, Loss: 0.9805756211280823 Top1 Train Accuracy:  21.370442708333332


100%|██████████| 3/3 [00:01<00:00,  1.82it/s]


Epoch: 96, Loss: 0.9250850081443787 Top1 Train Accuracy:  44.010416666666664
Epoch: 96, Loss: 0.9070762991905212 Top1 Train Accuracy:  66.68740844726562


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 97, Loss: 0.9793905019760132 Top1 Train Accuracy:  21.38671875


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 97, Loss: 0.9239380359649658 Top1 Train Accuracy:  44.059244791666664
Epoch: 97, Loss: 0.9056696891784668 Top1 Train Accuracy:  66.77311197916667


 33%|███▎      | 1/3 [00:00<00:01,  1.14it/s]

Epoch: 98, Loss: 0.9782195687294006 Top1 Train Accuracy:  21.38671875


100%|██████████| 3/3 [00:01<00:00,  1.83it/s]


Epoch: 98, Loss: 0.922806441783905 Top1 Train Accuracy:  44.075520833333336
Epoch: 98, Loss: 0.9042794704437256 Top1 Train Accuracy:  66.78938802083333


 33%|███▎      | 1/3 [00:00<00:01,  1.13it/s]

Epoch: 99, Loss: 0.9770622253417969 Top1 Train Accuracy:  21.38671875


100%|██████████| 3/3 [00:01<00:00,  1.81it/s]


Epoch: 99, Loss: 0.9216897487640381 Top1 Train Accuracy:  44.059244791666664
Epoch: 99, Loss: 0.902904748916626 Top1 Train Accuracy:  66.77311197916667


100%|██████████| 4/4 [00:02<00:00,  1.61it/s]

Top1 Test Accuracy:  64.24687194824219
Top5 Test Accuracy:  98.04898071289062



