In [None]:
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tools.dataset import CIFAR10
from torch.utils.data import Subset, DataLoader
import math

from PIL import Image, ImageFilter

In [None]:
import torchvision.models as models

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class Resnet20(nn.Module):
    """ResNet-20 style encoder that returns flattened convolutional features.

    Layout: a 3x3 stem convolution followed by nine residual blocks (two 3x3
    convolutions each) arranged in three stages of 16, 32 and 64 channels.
    The first block of the second and third stage uses stride 2; its shortcut
    is a strided 1x1 convolution + batch norm so the shapes match. Every other
    shortcut is an empty ``nn.Sequential`` (identity).

    ``forward`` stops after the last residual block and flattens the feature
    map, so a ``(N, 3, 32, 32)`` input yields ``(N, 4096)`` (64 x 8 x 8).
    ``avgPool`` and ``fc`` are registered but deliberately not used by
    ``forward``; they stay registered so the ``state_dict`` layout is stable.

    Sub-modules are created in the same order as before so that seeded
    initialisation and checkpoint keys are unchanged.
    """

    def __init__(self):
        super(Resnet20, self).__init__()
        # Stem
        self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        # Stage 1: 16 channels, stride 1 throughout
        self.conv2 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(16)
        self.shortcut3 = nn.Sequential()
        self.conv4 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(16)
        self.conv5 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(16)
        self.shortcut5 = nn.Sequential()
        self.conv6 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(16)
        self.conv7 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(16)
        self.shortcut7 = nn.Sequential()
        # Stage 2: 32 channels, first block downsamples with stride 2
        self.conv8 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.bn8 = nn.BatchNorm2d(32)
        self.conv9 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.bn9 = nn.BatchNorm2d(32)
        # Projection shortcut: channel count and spatial size both change
        self.shortcut9 = nn.Sequential(
            nn.Conv2d(16, 32, 1, stride=2, bias=False),
            nn.BatchNorm2d(32)
        )
        self.conv10 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.bn10 = nn.BatchNorm2d(32)
        self.conv11 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.bn11 = nn.BatchNorm2d(32)
        self.shortcut11 = nn.Sequential()
        self.conv12 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.bn12 = nn.BatchNorm2d(32)
        self.conv13 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.bn13 = nn.BatchNorm2d(32)
        self.shortcut13 = nn.Sequential()
        # Stage 3: 64 channels, first block downsamples with stride 2
        self.conv14 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.bn14 = nn.BatchNorm2d(64)
        self.conv15 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn15 = nn.BatchNorm2d(64)
        # Projection shortcut: channel count and spatial size both change
        self.shortcut15 = nn.Sequential(
            nn.Conv2d(32, 64, 1, stride=2, bias=False),
            nn.BatchNorm2d(64)
        )
        self.conv16 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn16 = nn.BatchNorm2d(64)
        self.conv17 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn17 = nn.BatchNorm2d(64)
        self.shortcut17 = nn.Sequential()
        self.conv18 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn18 = nn.BatchNorm2d(64)
        self.conv19 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn19 = nn.BatchNorm2d(64)
        self.shortcut19 = nn.Sequential()
        # Registered but unused by forward() (see class docstring).
        self.avgPool = nn.AvgPool2d((8, 8))
        self.fc = nn.Linear(64, 10)

    def _residual_block(self, x, conv_a, bn_a, conv_b, bn_b, shortcut):
        """Return relu(bn_b(conv_b(relu(bn_a(conv_a(x))))) + shortcut(x))."""
        out = F.relu(bn_a(conv_a(x)))
        out = bn_b(conv_b(out))
        # `x` is only read, never written in place, so no defensive clone is needed.
        return F.relu(out + shortcut(x))

    def forward(self, x):
        """Encode a batch of images into flattened feature vectors.

        Args:
            x: image batch with 3 channels; ``(N, 3, 32, 32)`` gives ``(N, 4096)``.

        Returns:
            The last residual block's activation flattened from dim 1 onward.
            No pooling, linear layer or softmax is applied.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        # Stage 1 (16 channels)
        out = self._residual_block(out, self.conv2, self.bn2, self.conv3, self.bn3, self.shortcut3)
        out = self._residual_block(out, self.conv4, self.bn4, self.conv5, self.bn5, self.shortcut5)
        out = self._residual_block(out, self.conv6, self.bn6, self.conv7, self.bn7, self.shortcut7)
        # Stage 2 (32 channels, first block strided)
        out = self._residual_block(out, self.conv8, self.bn8, self.conv9, self.bn9, self.shortcut9)
        out = self._residual_block(out, self.conv10, self.bn10, self.conv11, self.bn11, self.shortcut11)
        out = self._residual_block(out, self.conv12, self.bn12, self.conv13, self.bn13, self.shortcut13)
        # Stage 3 (64 channels, first block strided)
        out = self._residual_block(out, self.conv14, self.bn14, self.conv15, self.bn15, self.shortcut15)
        out = self._residual_block(out, self.conv16, self.bn16, self.conv17, self.bn17, self.shortcut17)
        out = self._residual_block(out, self.conv18, self.bn18, self.conv19, self.bn19, self.shortcut19)
        return torch.flatten(out, 1)


In [None]:
# Smoke test: push one random 3x32x32 image through a freshly initialised
# encoder and print the shape of the flattened feature vector.
example_data = torch.randn(1,3,32,32)

net = Resnet20()
# Call the module (not .forward) so nn.Module.__call__ and any hooks run;
# no_grad because this is only a shape check.
with torch.no_grad():
    out = net(example_data)
print(out.shape)

torch.Size([1, 4096])


In [None]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: width of the incoming feature vectors.
        hidden_size: width of the hidden layer.
        output_size: width of the returned vectors.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Project `x` through the hidden layer and return the raw output."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
def gaussian_blur(input_tensor, radius=1.0, device='cuda'):
    """Blur every image of a batch with PIL's Gaussian filter.

    Each element obtained by iterating `input_tensor` is converted to a PIL
    image, blurred, converted back to a tensor and moved to `device`; the
    results are stacked into one tensor.

    Args:
        input_tensor: iterable of image tensors (e.g. a batched tensor).
        radius: blur radius passed to ``ImageFilter.GaussianBlur``.
        device: device every blurred image is moved to before stacking.

    Returns:
        A tensor stacking the blurred images along a new first dimension.

    NOTE(review): the default device is 'cuda'; on a CPU-only machine the
    caller presumably has to pass `device` explicitly - confirm.
    NOTE(review): the PIL round trip presumably expects pixel values in
    [0, 1]; the loaders in this notebook apply Normalize first - confirm
    before enabling this augmentation.
    """
    # The converters and the filter hold no per-image state, so build them once.
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    blur = ImageFilter.GaussianBlur(radius=radius)

    blurred_images = [
        to_tensor(to_pil(image).filter(blur)).to(device)
        for image in input_tensor
    ]
    return torch.stack(blurred_images)

In [None]:
def train_encoder(net, projection_head, train_loader, base_lr, device, model_checkpoint, momentum=0.9, total_epochs=100):
    """Contrastively pre-train `net` and `projection_head` with the NT-Xent loss.

    Every batch is turned into two views - a colour-distorted one and a
    random-resized-crop + horizontal-flip one - which are encoded, projected
    and compared with ``nt_xent_loss`` at temperature 0.5. Labels from the
    loader are ignored. The learning rate is set once per epoch by
    ``cosine_decay_learning_rate`` with 10 warmup epochs.

    Args:
        net: encoder module.
        projection_head: module applied to the encoder output.
        train_loader: iterable of ``(inputs, targets)`` batches with ``len()``.
        base_lr: peak learning rate for SGD.
        device: device the modules and inputs are moved to.
        model_checkpoint: path the encoder ``state_dict`` is saved to after
            every epoch (overwritten each time).
        momentum: SGD momentum.
        total_epochs: number of passes over `train_loader`.

    Returns:
        `net` (the same object, trained in place).
    """
    warmup_epochs = 10
    temperature = 0.5

    # Move the modules once, before the optimizer is built, instead of on
    # every batch after the optimizer already holds the parameters.
    net.to(device)
    projection_head.to(device)

    optimizer = optim.SGD(
        list(net.parameters()) + list(projection_head.parameters()),
        lr=base_lr,
        weight_decay=1e-6,
        momentum=momentum,
    )

    # View 1: colour distortion only. For ablations this view can be swapped
    # for gaussian_blur(inputs) or crop_flip_augmentation(inputs).
    color_augmentation = get_color_distortion(s=1.0)

    # View 2: random resized crop + horizontal flip.
    crop_flip_augmentation = transforms.Compose([
        transforms.RandomResizedCrop(size=(32,32), scale=(0.08, 1)),
        transforms.RandomHorizontalFlip()
    ])
    # NOTE(review): both transforms are called on the whole batch tensor;
    # torchvision presumably draws one set of random parameters per call, so
    # every image in a batch would get the same distortion - confirm intent.

    for epoch in range(total_epochs):
        net.train()
        projection_head.train()
        # Per-epoch learning-rate schedule (warmup, then cosine decay).
        cosine_decay_learning_rate(epoch, total_epochs, warmup_epochs, base_lr, optimizer)

        total_loss = 0
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)
        for batch_idx, (inputs, _) in progress_bar:  # targets are not used in SimCLR
            inputs = inputs.to(device)

            # Encode and project the two views (order kept: view 1 first).
            z1 = projection_head(net(color_augmentation(inputs)))
            z2 = projection_head(net(crop_flip_augmentation(inputs)))

            loss = nt_xent_loss(z1, z2, temperature)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Show the running mean loss of the current epoch.
            progress_bar.set_postfix(loss=total_loss / (batch_idx + 1))

        progress_bar.close()

        # Overwrite the checkpoint after every epoch.
        torch.save(net.state_dict(), model_checkpoint)

    return net

In [None]:
def initialize_encoder_head():
    """Create a fresh encoder and its non-linear projection head.

    Returns:
        ``(net, projection_head)``: a ``Resnet20`` and an ``MLP`` mapping
        4096 input features through a 2048-wide hidden layer to 128 outputs.
    """
    # The encoder is built first so the construction order stays the same.
    encoder = Resnet20()
    head = MLP(input_size=4096, hidden_size=2048, output_size=128)
    return encoder, head

In [None]:
def create_train_val(train_batch, val_batch):
    """Build the training and validation DataLoaders.

    Both splits are read through the project's ``CIFAR10`` dataset class from
    ``./data`` (with ``download=True``) and use the same preprocessing:
    ``ToTensor`` followed by per-channel ``Normalize``. No augmentation is
    applied here.

    Args:
        train_batch: batch size of the training loader.
        val_batch: batch size of the validation loader.

    Returns:
        ``(train_loader, val_loader)``. The training loader is shuffled; the
        validation loader is not, since the order of validation samples does
        not affect the accuracy or loss computed from it.
    """
    DATA_ROOT = "./data"

    # Train and validation share the exact same preprocessing.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_set = CIFAR10(
        root=DATA_ROOT,
        mode='train',
        download=True,
        transform=preprocess
    )
    val_set = CIFAR10(
        root=DATA_ROOT,
        mode='val',
        download=True,
        transform=preprocess
    )

    train_loader = DataLoader(
        train_set,
        batch_size=train_batch,
        shuffle=True,
        num_workers=4
    )
    val_loader = DataLoader(
        val_set,
        batch_size=val_batch,
        shuffle=False,
        num_workers=4
    )

    return train_loader, val_loader

In [None]:
def create_train_finetune(train_batch):
    """Build the augmented training DataLoader used for supervised fine-tuning.

    Images go through a padded random crop (32 px, padding 4) and a random
    horizontal flip, then ``ToTensor`` and per-channel ``Normalize``.

    Args:
        train_batch: batch size of the returned loader.

    Returns:
        A shuffled DataLoader over the ``mode='train'`` split.
    """
    augment_and_normalize = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    finetune_set = CIFAR10(
        root="./data",
        mode='train',
        download=True,
        transform=augment_and_normalize
    )

    # An earlier, disabled experiment fine-tuned on a fraction of the data by
    # wrapping the dataset in Subset(train_set, range(int(0.01 * len(train_set))))
    # (i.e. the first 1% of indices) before building the loader.
    return DataLoader(
        finetune_set,
        batch_size=train_batch,
        shuffle=True,
        num_workers=4
    )

In [None]:
def create_test(batch_size=1):
    """Build the DataLoader for the ``mode='test'`` split.

    Preprocessing is ``ToTensor`` followed by per-channel ``Normalize``, the
    same as the validation loader.

    Args:
        batch_size: samples per batch. Defaults to 1, the value that used to
            be hard-coded; pass a larger value to evaluate faster.

    Returns:
        An unshuffled DataLoader (sample order does not affect accuracy).
    """
    DATA_ROOT = "./data"
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    test_set = CIFAR10(
        root=DATA_ROOT,
        mode='test',
        download=True,
        transform=transform_test
    )

    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )
    return test_loader

In [None]:
def get_color_distortion(s=1.0):
    """Compose the random colour-distortion augmentation.

    Args:
        s: strength multiplier applied to all four ColorJitter arguments
            (0.8*s, 0.8*s, 0.8*s, 0.2*s, passed positionally).

    Returns:
        A ``transforms.Compose`` that applies the colour jitter with
        probability 0.8 and then a random grayscale with probability 0.2.
    """
    jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])

In [None]:
def nt_xent_loss(z1, z2, temperature=0.5):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss.

    Row ``i`` of `z1` and row ``i`` of `z2` form a positive pair; every other
    row of the concatenated batch acts as a negative.

    Args:
        z1: projections of the first view, one row per sample.
        z2: projections of the second view, same shape as `z1`.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all ``2 * N`` rows.
    """
    batch = z1.size(0)
    dev = z1.device

    # L2-normalise each row so the dot products below are cosine similarities.
    reps = torch.cat([F.normalize(z1, dim=1), F.normalize(z2, dim=1)], dim=0)
    logits = torch.matmul(reps, reps.T) / temperature

    # A sample must not be contrasted with itself: push the diagonal to a
    # very large negative value so it vanishes under softmax.
    self_mask = torch.eye(2 * batch, dtype=torch.bool).to(dev)
    logits = logits.masked_fill(self_mask, -1e9)

    # Row i (< N) has its positive at column i + N; row i + N at column i.
    first_half = torch.arange(batch)
    targets = torch.cat([first_half + batch, first_half]).to(dev)

    return F.cross_entropy(logits, targets)

In [None]:
def warmup_learning_rate(epoch, warmup_epochs, base_lr, optimizer):
    """Set a linearly ramped learning rate for a warmup epoch.

    The rate is ``base_lr * (epoch + 1) / warmup_epochs``, so the first epoch
    (``epoch == 0``) already trains with a non-zero rate and the last warmup
    epoch reaches ``base_lr``. (Ramping from ``epoch / warmup_epochs`` would
    spend the whole first epoch at a learning rate of exactly 0.)

    Args:
        epoch: zero-based epoch index; expected to be < `warmup_epochs`.
        warmup_epochs: number of warmup epochs (must be > 0).
        base_lr: learning rate reached at the end of warmup.
        optimizer: object with a ``param_groups`` list of dicts; the ``'lr'``
            entry of every group is overwritten.
    """
    lr = base_lr * (epoch + 1) / warmup_epochs
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def cosine_decay_learning_rate(epoch, total_epochs, warmup_epochs, base_lr, optimizer):
    """Set the learning rate for `epoch`: linear warmup, then cosine decay.

    For ``epoch < warmup_epochs`` this delegates to ``warmup_learning_rate``.
    Afterwards the rate follows half a cosine period, starting at `base_lr`
    at ``epoch == warmup_epochs`` and approaching 0 at ``total_epochs``.

    Args:
        epoch: zero-based epoch index.
        total_epochs: total number of training epochs.
        warmup_epochs: number of warmup epochs.
        base_lr: peak learning rate.
        optimizer: object with a ``param_groups`` list of dicts; the ``'lr'``
            entry of every group is overwritten.
    """
    if epoch < warmup_epochs:
        warmup_learning_rate(epoch, warmup_epochs, base_lr, optimizer)
        return

    # Guard the degenerate schedule (total_epochs <= warmup_epochs), which
    # would otherwise divide by zero or by a negative span.
    decay_epochs = max(1, total_epochs - warmup_epochs)
    # Clamp so epochs past the end stay at the cosine minimum.
    elapsed = min(epoch - warmup_epochs, decay_epochs)
    lr = 0.5 * base_lr * (1 + math.cos(math.pi * elapsed / decay_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [None]:
def encode_for_classifier(train_loader, device, net):
    """Run every batch of a loader through `net` and collect the results.

    `net` is switched to eval mode and left there. Inputs are moved to
    `device` before the forward pass; `net` itself is not moved.

    Args:
        train_loader: iterable of ``(inputs, targets)`` batches.
        device: device used for the forward pass and for the returned tensors.
        net: module that maps inputs to feature tensors.

    Returns:
        ``(X, y)``: all features and all targets concatenated along dim 0,
        both on `device` and without gradient history.
    """
    net.eval()
    feature_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_inputs, batch_targets in train_loader:
            encoded = net(batch_inputs.to(device))
            # Park each batch on the CPU while iterating; the full tensors
            # are moved to `device` once, after concatenation.
            feature_chunks.append(encoded.detach().cpu())
            label_chunks.append(batch_targets.detach().cpu())

    X = torch.cat(feature_chunks, dim=0).to(device)
    y = torch.cat(label_chunks, dim=0).to(device)
    return X, y

In [None]:
# from TA
class LogisticRegression(nn.Module):
    """Single linear layer mapping feature vectors to class logits.

    No softmax is applied here; callers in this notebook feed the logits to
    ``F.cross_entropy``.

    Args:
        n_features: width of the input feature vectors.
        n_classes: number of output logits.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()
        self.model = nn.Linear(in_features=n_features, out_features=n_classes)

    def forward(self, x):
        """Return the logits for `x`."""
        logits = self.model(x)
        return logits

In [None]:
def train_classifier_og(net, train_loader, val_loader, device, total_epochs):
    """Linear evaluation: fit a logistic-regression head on frozen features.

    The encoder is frozen, the whole training loader is encoded once, and a
    10-class linear classifier is fitted on all encoded features with
    L-BFGS. After every L-BFGS step the validation accuracy is printed.

    Side effect: ``requires_grad`` of every parameter of `net` is set to
    False and stays that way after the call.

    Args:
        net: encoder module; assumed to already live on `device`.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        device: device used for features, classifier and validation.
        total_epochs: number of L-BFGS steps (one per "epoch").

    Returns:
        None. The fitted classifier is not returned by this variant.
    """
    # Freeze the encoder: only the linear head is trained.
    for param in net.parameters():
        param.requires_grad = False

    # Encode the full training set once; L-BFGS works on the whole batch.
    X, y = encode_for_classifier(train_loader, device, net)

    # Take the input width from the encoded features instead of hard-coding it.
    classifier = LogisticRegression(X.shape[1], 10).to(device)
    classifier_optimizer = optim.LBFGS(classifier.parameters(), lr=1)

    def closure():
        # L-BFGS re-evaluates the objective several times per step, so
        # optimizer.step() requires this closure.
        classifier_optimizer.zero_grad()
        loss = F.cross_entropy(classifier(X), y)
        loss.backward()
        return loss

    for epoch in range(total_epochs):
        classifier.train()
        # Train phase: one L-BFGS step over all features.
        with tqdm(total=1, desc=f"Epoch {epoch+1}/{total_epochs} - Training", leave=False) as pbar_train:
            classifier_optimizer.step(closure)
            pbar_train.update(1)

        # Validation phase: no gradients are needed.
        classifier.eval()
        correct, total = 0, 0
        with torch.no_grad(), tqdm(val_loader, desc="Validating", leave=False) as pbar_val:
            for inputs, targets in pbar_val:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = classifier(net(inputs))
                predicted = torch.max(outputs, 1)[1]
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                pbar_val.set_postfix(acc=(correct / total))

        val_accuracy = correct / total
        print(f'Epoch {epoch+1}/{total_epochs}: Validation Accuracy: {val_accuracy:.2f}')


In [None]:
def train_classifier(net, train_loader, val_loader, device, total_epochs):
    """Linear evaluation on a random 10% of the encoded training set.

    The encoder is frozen, the whole training loader is encoded once, a
    random 10% of the encoded samples is kept as the labelled set, and a
    10-class linear classifier is fitted on it with L-BFGS. After every
    L-BFGS step the validation accuracy is printed.

    Side effect: ``requires_grad`` of every parameter of `net` is set to
    False and stays that way after the call.

    Args:
        net: encoder module; assumed to already live on `device`.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        device: device used for features, classifier and validation.
        total_epochs: number of L-BFGS steps (one per "epoch").

    Returns:
        The fitted ``LogisticRegression`` classifier.
    """
    # Freeze the encoder: only the linear head is trained.
    for param in net.parameters():
        param.requires_grad = False

    # Encode the full training set once; L-BFGS works on the whole batch.
    X, y = encode_for_classifier(train_loader, device, net)

    # Keep a random 10% of the samples as the labelled subset.
    labeled_indices = torch.randperm(len(y))
    labeled_indices = labeled_indices[:int(0.1 * len(y))]

    X_labeled = X[labeled_indices]
    y_labeled = y[labeled_indices]

    # Take the input width from the encoded features instead of hard-coding it.
    classifier = LogisticRegression(X.shape[1], 10).to(device)
    classifier_optimizer = optim.LBFGS(classifier.parameters(), lr=1)

    def closure():
        # L-BFGS re-evaluates the objective several times per step, so
        # optimizer.step() requires this closure.
        classifier_optimizer.zero_grad()
        loss = F.cross_entropy(classifier(X_labeled), y_labeled)
        loss.backward()
        return loss

    for epoch in range(total_epochs):
        classifier.train()
        # Train phase: one L-BFGS step over the labelled subset.
        with tqdm(total=1, desc=f"Epoch {epoch+1}/{total_epochs} - Training", leave=False) as pbar_train:
            classifier_optimizer.step(closure)
            pbar_train.update(1)

        # Validation phase: no gradients are needed.
        classifier.eval()
        correct, total = 0, 0
        with torch.no_grad(), tqdm(val_loader, desc="Validating", leave=False) as pbar_val:
            for inputs, targets in pbar_val:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = classifier(net(inputs))
                predicted = torch.max(outputs, 1)[1]
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                pbar_val.set_postfix(acc=(correct / total))

        val_accuracy = correct / total
        print(f'Epoch {epoch+1}/{total_epochs}: Validation Accuracy: {val_accuracy:.2f}')

    return classifier

In [None]:
def fine_tune(net, train_loader, val_loader, device, total_epochs, model_checkpoint_net, model_checkpoint_clf):
    """Fine-tune the encoder end-to-end under a freshly created linear head.

    A new 10-class ``LogisticRegression`` is stacked on `net` and both are
    trained with cross-entropy using two SGD optimizers (lr 0.005 for the
    head, 0.01 for the encoder; both constant). After each epoch the pair is
    evaluated on `val_loader`, and both ``state_dict``s are saved whenever
    the validation accuracy improves on the best seen so far.

    Args:
        net: encoder module, trained in place.
        train_loader: iterable of ``(inputs, targets)`` batches with ``len()``.
        val_loader: iterable of ``(inputs, targets)`` batches with ``len()``.
        device: device the modules and batches are moved to.
        total_epochs: number of passes over `train_loader`.
        model_checkpoint_net: path for the best encoder ``state_dict``.
        model_checkpoint_clf: path for the best classifier ``state_dict``.

    Returns:
        None. Progress and the best validation accuracy are printed.
    """
    # Move the encoder once, before its optimizer is created, rather than on
    # every batch.
    net.to(device)
    # 4096 is the flattened Resnet20 feature width for 32x32 inputs.
    classifier = LogisticRegression(4096, 10).to(device)

    optimizer_clf = optim.SGD(list(classifier.parameters()), lr=0.005, weight_decay=1e-6,
                              momentum=0.9)
    optimizer_net = optim.SGD(list(net.parameters()), lr=0.01, weight_decay=1e-6,
                              momentum=0.9)

    best_val_acc = 0
    for epoch in range(total_epochs):
        # ---- training ----
        net.train()
        classifier.train()

        # Use % so the epoch number is actually substituted into the message.
        print("Epoch %d:" % epoch)

        total_examples = 0
        correct_examples = 0
        total_loss = 0

        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)
        for batch_idx, (inputs, targets) in progress_bar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = classifier(net(inputs))
            loss = F.cross_entropy(logits, targets)
            total_loss += loss.item()

            optimizer_clf.zero_grad()
            optimizer_net.zero_grad()
            loss.backward()
            optimizer_clf.step()
            optimizer_net.step()

            # Count correct predictions (index of the largest logit per row).
            predicted_classes = logits.max(1)[1]
            total_examples += targets.size(0)
            correct_examples += predicted_classes.eq(targets).sum().item()

            # Show the running mean loss of the current epoch.
            progress_bar.set_postfix(loss=total_loss / (batch_idx + 1))

        progress_bar.close()

        avg_loss = total_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))

        # ---- validation ----
        net.eval()
        classifier.eval()

        total_examples = 0
        correct_examples = 0
        val_loss = 0

        # Gradients are not needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                logits = classifier(net(inputs))
                val_loss += F.cross_entropy(logits, targets).item()

                predicted_classes = logits.max(1)[1]
                total_examples += targets.size(0)
                correct_examples += predicted_classes.eq(targets).sum().item()

        avg_loss = val_loss / len(val_loader)
        avg_acc = correct_examples / total_examples
        print("Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc))

        # Keep only the best-performing weights on disk.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc

            torch.save(net.state_dict(), model_checkpoint_net)
            torch.save(classifier.state_dict(), model_checkpoint_clf)

        print('')

    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")


In [None]:
def fine_tune_load(classifier, net, train_loader, val_loader, device, total_epochs, model_checkpoint_net, model_checkpoint_clf):
    """Fine-tune an existing encoder/classifier pair end-to-end.

    Same procedure as fine-tuning from scratch, except that the classifier
    is supplied by the caller (e.g. one restored from a checkpoint) instead
    of being created here. Both modules are trained with cross-entropy using
    two SGD optimizers (lr 0.005 for the classifier, 0.01 for the encoder;
    both constant). After each epoch the pair is evaluated on `val_loader`,
    and both ``state_dict``s are saved whenever the validation accuracy
    improves on the best seen during this call.

    Args:
        classifier: module mapping encoder features to class logits; trained
            in place.
        net: encoder module, trained in place.
        train_loader: iterable of ``(inputs, targets)`` batches with ``len()``.
        val_loader: iterable of ``(inputs, targets)`` batches with ``len()``.
        device: device the modules and batches are moved to.
        total_epochs: number of passes over `train_loader`.
        model_checkpoint_net: path for the best encoder ``state_dict``.
        model_checkpoint_clf: path for the best classifier ``state_dict``.

    Returns:
        None. Progress and the best validation accuracy are printed.
    """
    # Move both modules once, before their optimizers are created, rather
    # than on every batch.
    net.to(device)
    classifier.to(device)

    optimizer_clf = optim.SGD(list(classifier.parameters()), lr=0.005, weight_decay=1e-6,
                              momentum=0.9)
    optimizer_net = optim.SGD(list(net.parameters()), lr=0.01, weight_decay=1e-6,
                              momentum=0.9)

    best_val_acc = 0
    for epoch in range(total_epochs):
        # ---- training ----
        net.train()
        classifier.train()

        # Use % so the epoch number is actually substituted into the message.
        print("Epoch %d:" % epoch)

        total_examples = 0
        correct_examples = 0
        total_loss = 0

        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)
        for batch_idx, (inputs, targets) in progress_bar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = classifier(net(inputs))
            loss = F.cross_entropy(logits, targets)
            total_loss += loss.item()

            optimizer_clf.zero_grad()
            optimizer_net.zero_grad()
            loss.backward()
            optimizer_clf.step()
            optimizer_net.step()

            # Count correct predictions (index of the largest logit per row).
            predicted_classes = logits.max(1)[1]
            total_examples += targets.size(0)
            correct_examples += predicted_classes.eq(targets).sum().item()

            # Show the running mean loss of the current epoch.
            progress_bar.set_postfix(loss=total_loss / (batch_idx + 1))

        progress_bar.close()

        avg_loss = total_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))

        # ---- validation ----
        net.eval()
        classifier.eval()

        total_examples = 0
        correct_examples = 0
        val_loss = 0

        # Gradients are not needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                logits = classifier(net(inputs))
                val_loss += F.cross_entropy(logits, targets).item()

                predicted_classes = logits.max(1)[1]
                total_examples += targets.size(0)
                correct_examples += predicted_classes.eq(targets).sum().item()

        avg_loss = val_loss / len(val_loader)
        avg_acc = correct_examples / total_examples
        print("Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc))

        # Keep only the best-performing weights on disk.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc

            torch.save(net.state_dict(), model_checkpoint_net)
            torch.save(classifier.state_dict(), model_checkpoint_clf)

        print('')

    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")

In [None]:
def test_classifier(net, classifier, test_loader, device):
    net.eval()  # Ensure the network is in evaluation mode
    classifier.eval()  # Ensure the classifier is in evaluation mode

    correct, total = 0, 0
    with torch.no_grad():  # No gradients needed
        for _ , (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Extract features and pass through the classifier
            features = net(inputs)
            outputs = classifier(features)
            _, predicted = torch.max(outputs.data, 1)

            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    test_accuracy = correct / total
    print(f'Test Accuracy: {test_accuracy:.2f}')

    return test_accuracy

In [None]:
# Full pipeline run: train the encoder, then fine-tune with a classifier.
# Sweep plan for the histogram: temperature is kept at 0.5,
# learning rates are taken from (0.5, 1.0, 1.5),
# momentum is kept at 0.9,
# and batch sizes can include (256, 512, 1024, 2048, 4096).
# Lines tagged "CHANGE" are the values edited between runs.

# Fresh encoder and projection head for this run
net, projection_head = initialize_encoder_head()
train_batch = 512 # CHANGE
val_batch = 100
base_lr = 1.0
momentum = 0.9
total_epochs = 10 # CHANGE
fine_tune_epochs = 10
# Checkpoint names passed to train_encoder / fine_tune below.
# NOTE(review): the tag looks like b<batch>_lr<lr>_e<encoder epochs>_e<fine-tune epochs>,
# but "_e100" does not match fine_tune_epochs = 10 above -- confirm which is intended.
model_checkpoint_encoder = 'encoder_b512_lr1p0_e10_e100'
model_checkpoint_net = 'net_b512_lr1p0_e10_e100'
model_checkpoint_clf = 'clf_b512_lr1p0_e10_e100'

# Build the loaders; the fine-tuning loader reuses the training batch size
train_loader, val_loader = create_train_val(train_batch, val_batch)
finetune_loader = create_train_finetune(train_batch)
# Stage 1: encoder training; the returned model replaces `net`
net = train_encoder(net, projection_head, train_loader, base_lr, device, model_checkpoint_encoder, momentum, total_epochs)
# Stage 2: fine-tuning, validated on val_loader (return value, if any, is not used)
fine_tune(net, finetune_loader, val_loader, device, fine_tune_epochs, model_checkpoint_net, model_checkpoint_clf)

Using downloaded and verified file: ./data/cifar10_trainval_F22.zip
Extracting ./data/cifar10_trainval_F22.zip to ./data
Files already downloaded and verified
Using downloaded and verified file: ./data/cifar10_trainval_F22.zip
Extracting ./data/cifar10_trainval_F22.zip to ./data
Files already downloaded and verified
Using downloaded and verified file: ./data/cifar10_trainval_F22.zip
Extracting ./data/cifar10_trainval_F22.zip to ./data
Files already downloaded and verified


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)


Epoch 1/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 2/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 3/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 4/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 5/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 6/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 7/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 8/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 9/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 10/10:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch %d: 0


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)


Epoch 1/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 2.4407, Training accuracy: 0.3066
Validation loss: 1.8108, Validation accuracy: 0.3868

Epoch %d: 1


Epoch 2/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.7246, Training accuracy: 0.3878
Validation loss: 1.6327, Validation accuracy: 0.4394

Epoch %d: 2


Epoch 3/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.6439, Training accuracy: 0.4138
Validation loss: 1.5245, Validation accuracy: 0.4600

Epoch %d: 3


Epoch 4/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.5498, Training accuracy: 0.4407
Validation loss: 1.4233, Validation accuracy: 0.4938

Epoch %d: 4


Epoch 5/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.5149, Training accuracy: 0.4528
Validation loss: 1.4066, Validation accuracy: 0.5064

Epoch %d: 5


Epoch 6/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.4958, Training accuracy: 0.4626
Validation loss: 1.3968, Validation accuracy: 0.5052

Epoch %d: 6


Epoch 7/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.4406, Training accuracy: 0.4784
Validation loss: 1.3320, Validation accuracy: 0.5264

Epoch %d: 7


Epoch 8/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.4069, Training accuracy: 0.4919
Validation loss: 1.3181, Validation accuracy: 0.5312

Epoch %d: 8


Epoch 9/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.3768, Training accuracy: 0.5047
Validation loss: 1.2979, Validation accuracy: 0.5372

Epoch %d: 9


Epoch 10/10:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.3571, Training accuracy: 0.5106
Validation loss: 1.2909, Validation accuracy: 0.5420

==> Optimization finished! Best validation accuracy: 0.5420


In [None]:
# Save the current weights of `net` (whatever state it is in when this cell runs).
# NOTE(review): the "e50" tag does not match any epoch count set in the cells
# shown here (total_epochs = 10, fine_tune_epochs = 10) -- confirm the name
# still describes the run that produced these weights.
torch.save(net.state_dict(), "net_b512_lr1p0_e50_e100")

In [None]:
## IF ALREADY HAVE TRAINED ENCODER
# Rebuild the encoder and restore previously saved weights.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
net = Resnet20().to(device)
net.load_state_dict(torch.load("encoder_b512_lr1p0_e250_e100", map_location=device))

train_batch = 512
val_batch = 250 # CHANGE
# Number of fine-tuning epochs for THIS cell. It is passed to fine_tune below
# so the run no longer depends on `fine_tune_epochs` left over from another cell.
classify_total_epochs = 100

# Checkpoint names handed to fine_tune
model_checkpoint_net = 'net_b512_lr1p0_e250_e100'
model_checkpoint_clf = 'clf_b512_lr1p0_e250_e100'

train_loader, val_loader = create_train_val(train_batch, val_batch)
finetune_loader = create_train_finetune(train_batch)

fine_tune(net, finetune_loader, val_loader, device, classify_total_epochs, model_checkpoint_net, model_checkpoint_clf)

Using downloaded and verified file: ./data/cifar10_trainval_F22.zip
Extracting ./data/cifar10_trainval_F22.zip to ./data
Files already downloaded and verified
Using downloaded and verified file: ./data/cifar10_trainval_F22.zip
Extracting ./data/cifar10_trainval_F22.zip to ./data
Files already downloaded and verified
Using downloaded and verified file: ./data/cifar10_trainval_F22.zip
Extracting ./data/cifar10_trainval_F22.zip to ./data
Files already downloaded and verified
Epoch %d: 0


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)


Epoch 1/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.4865, Training accuracy: 0.5073
Validation loss: 1.1958, Validation accuracy: 0.5906

Epoch %d: 1


Epoch 2/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.2005, Training accuracy: 0.5792
Validation loss: 1.1392, Validation accuracy: 0.6054

Epoch %d: 2


Epoch 3/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.1412, Training accuracy: 0.5963
Validation loss: 1.0782, Validation accuracy: 0.6224

Epoch %d: 3


Epoch 4/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.0830, Training accuracy: 0.6188
Validation loss: 1.0068, Validation accuracy: 0.6482

Epoch %d: 4


Epoch 5/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.0627, Training accuracy: 0.6245
Validation loss: 1.0398, Validation accuracy: 0.6392

Epoch %d: 5


Epoch 6/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 1.0188, Training accuracy: 0.6386
Validation loss: 0.9534, Validation accuracy: 0.6634

Epoch %d: 6


Epoch 7/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.9877, Training accuracy: 0.6461
Validation loss: 0.9767, Validation accuracy: 0.6616

Epoch %d: 7


Epoch 8/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.9482, Training accuracy: 0.6633
Validation loss: 0.9332, Validation accuracy: 0.6746

Epoch %d: 8


Epoch 9/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.9385, Training accuracy: 0.6676
Validation loss: 0.9509, Validation accuracy: 0.6714

Epoch %d: 9


Epoch 10/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.9111, Training accuracy: 0.6747
Validation loss: 0.9164, Validation accuracy: 0.6868

Epoch %d: 10


Epoch 11/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.8952, Training accuracy: 0.6794
Validation loss: 0.9139, Validation accuracy: 0.6884

Epoch %d: 11


Epoch 12/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.8905, Training accuracy: 0.6825
Validation loss: 0.8827, Validation accuracy: 0.6930

Epoch %d: 12


Epoch 13/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.8711, Training accuracy: 0.6879
Validation loss: 0.9087, Validation accuracy: 0.6830

Epoch %d: 13


Epoch 14/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.8528, Training accuracy: 0.6940
Validation loss: 0.8924, Validation accuracy: 0.6974

Epoch %d: 14


Epoch 15/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.8430, Training accuracy: 0.6990
Validation loss: 0.8677, Validation accuracy: 0.7030

Epoch %d: 15


Epoch 16/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.8360, Training accuracy: 0.7025
Validation loss: 0.8570, Validation accuracy: 0.7072

Epoch %d: 16


Epoch 17/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.8236, Training accuracy: 0.7078
Validation loss: 0.8687, Validation accuracy: 0.7016

Epoch %d: 17


Epoch 18/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.8122, Training accuracy: 0.7127
Validation loss: 0.8423, Validation accuracy: 0.7146

Epoch %d: 18


Epoch 19/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7994, Training accuracy: 0.7131
Validation loss: 0.8528, Validation accuracy: 0.7140

Epoch %d: 19


Epoch 20/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7996, Training accuracy: 0.7160
Validation loss: 0.8291, Validation accuracy: 0.7162

Epoch %d: 20


Epoch 21/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7939, Training accuracy: 0.7152
Validation loss: 0.8256, Validation accuracy: 0.7158

Epoch %d: 21


Epoch 22/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7799, Training accuracy: 0.7233
Validation loss: 0.8131, Validation accuracy: 0.7206

Epoch %d: 22


Epoch 23/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7684, Training accuracy: 0.7263
Validation loss: 0.8106, Validation accuracy: 0.7182

Epoch %d: 23


Epoch 24/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7695, Training accuracy: 0.7260
Validation loss: 0.8134, Validation accuracy: 0.7212

Epoch %d: 24


Epoch 25/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7633, Training accuracy: 0.7277
Validation loss: 0.8099, Validation accuracy: 0.7230

Epoch %d: 25


Epoch 26/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7576, Training accuracy: 0.7298
Validation loss: 0.8082, Validation accuracy: 0.7268

Epoch %d: 26


Epoch 27/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7415, Training accuracy: 0.7359
Validation loss: 0.7920, Validation accuracy: 0.7324

Epoch %d: 27


Epoch 28/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7407, Training accuracy: 0.7356
Validation loss: 0.8047, Validation accuracy: 0.7272

Epoch %d: 28


Epoch 29/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7365, Training accuracy: 0.7367
Validation loss: 0.7801, Validation accuracy: 0.7384

Epoch %d: 29


Epoch 30/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7321, Training accuracy: 0.7385
Validation loss: 0.7792, Validation accuracy: 0.7338

Epoch %d: 30


Epoch 31/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7275, Training accuracy: 0.7425
Validation loss: 0.7865, Validation accuracy: 0.7320

Epoch %d: 31


Epoch 32/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7169, Training accuracy: 0.7427
Validation loss: 0.7795, Validation accuracy: 0.7392

Epoch %d: 32


Epoch 33/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7121, Training accuracy: 0.7474
Validation loss: 0.7678, Validation accuracy: 0.7364

Epoch %d: 33


Epoch 34/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7148, Training accuracy: 0.7442
Validation loss: 0.7655, Validation accuracy: 0.7394

Epoch %d: 34


Epoch 35/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.7064, Training accuracy: 0.7473
Validation loss: 0.7672, Validation accuracy: 0.7346

Epoch %d: 35


Epoch 36/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6927, Training accuracy: 0.7541
Validation loss: 0.7615, Validation accuracy: 0.7424

Epoch %d: 36


Epoch 37/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6904, Training accuracy: 0.7546
Validation loss: 0.7451, Validation accuracy: 0.7498

Epoch %d: 37


Epoch 38/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6915, Training accuracy: 0.7532
Validation loss: 0.7616, Validation accuracy: 0.7362

Epoch %d: 38


Epoch 39/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6862, Training accuracy: 0.7546
Validation loss: 0.7492, Validation accuracy: 0.7512

Epoch %d: 39


Epoch 40/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6797, Training accuracy: 0.7578
Validation loss: 0.7442, Validation accuracy: 0.7496

Epoch %d: 40


Epoch 41/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6801, Training accuracy: 0.7591
Validation loss: 0.7433, Validation accuracy: 0.7462

Epoch %d: 41


Epoch 42/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6756, Training accuracy: 0.7585
Validation loss: 0.7268, Validation accuracy: 0.7544

Epoch %d: 42


Epoch 43/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6655, Training accuracy: 0.7618
Validation loss: 0.7372, Validation accuracy: 0.7520

Epoch %d: 43


Epoch 44/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6687, Training accuracy: 0.7613
Validation loss: 0.7423, Validation accuracy: 0.7482

Epoch %d: 44


Epoch 45/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6647, Training accuracy: 0.7640
Validation loss: 0.7278, Validation accuracy: 0.7564

Epoch %d: 45


Epoch 46/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6597, Training accuracy: 0.7647
Validation loss: 0.7200, Validation accuracy: 0.7526

Epoch %d: 46


Epoch 47/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6518, Training accuracy: 0.7676
Validation loss: 0.7379, Validation accuracy: 0.7508

Epoch %d: 47


Epoch 48/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6520, Training accuracy: 0.7672
Validation loss: 0.7182, Validation accuracy: 0.7534

Epoch %d: 48


Epoch 49/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6506, Training accuracy: 0.7681
Validation loss: 0.7226, Validation accuracy: 0.7552

Epoch %d: 49


Epoch 50/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6398, Training accuracy: 0.7732
Validation loss: 0.7291, Validation accuracy: 0.7566

Epoch %d: 50


Epoch 51/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6380, Training accuracy: 0.7740
Validation loss: 0.7164, Validation accuracy: 0.7592

Epoch %d: 51


Epoch 52/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6343, Training accuracy: 0.7731
Validation loss: 0.7157, Validation accuracy: 0.7636

Epoch %d: 52


Epoch 53/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6337, Training accuracy: 0.7737
Validation loss: 0.7053, Validation accuracy: 0.7622

Epoch %d: 53


Epoch 54/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6259, Training accuracy: 0.7776
Validation loss: 0.7099, Validation accuracy: 0.7602

Epoch %d: 54


Epoch 55/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6267, Training accuracy: 0.7754
Validation loss: 0.7077, Validation accuracy: 0.7604

Epoch %d: 55


Epoch 56/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6279, Training accuracy: 0.7764
Validation loss: 0.7101, Validation accuracy: 0.7664

Epoch %d: 56


Epoch 57/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6179, Training accuracy: 0.7801
Validation loss: 0.6901, Validation accuracy: 0.7724

Epoch %d: 57


Epoch 58/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6162, Training accuracy: 0.7809
Validation loss: 0.7036, Validation accuracy: 0.7658

Epoch %d: 58


Epoch 59/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6125, Training accuracy: 0.7809
Validation loss: 0.7059, Validation accuracy: 0.7602

Epoch %d: 59


Epoch 60/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6091, Training accuracy: 0.7834
Validation loss: 0.7039, Validation accuracy: 0.7672

Epoch %d: 60


Epoch 61/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6059, Training accuracy: 0.7845
Validation loss: 0.6903, Validation accuracy: 0.7700

Epoch %d: 61


Epoch 62/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6104, Training accuracy: 0.7844
Validation loss: 0.6908, Validation accuracy: 0.7664

Epoch %d: 62


Epoch 63/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6018, Training accuracy: 0.7859
Validation loss: 0.6885, Validation accuracy: 0.7736

Epoch %d: 63


Epoch 64/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.6009, Training accuracy: 0.7856
Validation loss: 0.6777, Validation accuracy: 0.7768

Epoch %d: 64


Epoch 65/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5956, Training accuracy: 0.7870
Validation loss: 0.6794, Validation accuracy: 0.7746

Epoch %d: 65


Epoch 66/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5961, Training accuracy: 0.7876
Validation loss: 0.6804, Validation accuracy: 0.7744

Epoch %d: 66


Epoch 67/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5895, Training accuracy: 0.7907
Validation loss: 0.6756, Validation accuracy: 0.7746

Epoch %d: 67


Epoch 68/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5864, Training accuracy: 0.7909
Validation loss: 0.6692, Validation accuracy: 0.7770

Epoch %d: 68


Epoch 69/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5873, Training accuracy: 0.7902
Validation loss: 0.6777, Validation accuracy: 0.7778

Epoch %d: 69


Epoch 70/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5781, Training accuracy: 0.7951
Validation loss: 0.6879, Validation accuracy: 0.7700

Epoch %d: 70


Epoch 71/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5822, Training accuracy: 0.7924
Validation loss: 0.6707, Validation accuracy: 0.7808

Epoch %d: 71


Epoch 72/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5756, Training accuracy: 0.7946
Validation loss: 0.6744, Validation accuracy: 0.7796

Epoch %d: 72


Epoch 73/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5741, Training accuracy: 0.7969
Validation loss: 0.6859, Validation accuracy: 0.7742

Epoch %d: 73


Epoch 74/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5729, Training accuracy: 0.7944
Validation loss: 0.6809, Validation accuracy: 0.7748

Epoch %d: 74


Epoch 75/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5714, Training accuracy: 0.7979
Validation loss: 0.6739, Validation accuracy: 0.7786

Epoch %d: 75


Epoch 76/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5647, Training accuracy: 0.7993
Validation loss: 0.6639, Validation accuracy: 0.7834

Epoch %d: 76


Epoch 77/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5612, Training accuracy: 0.8005
Validation loss: 0.6684, Validation accuracy: 0.7806

Epoch %d: 77


Epoch 78/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5603, Training accuracy: 0.7996
Validation loss: 0.6602, Validation accuracy: 0.7810

Epoch %d: 78


Epoch 79/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5605, Training accuracy: 0.8004
Validation loss: 0.6813, Validation accuracy: 0.7822

Epoch %d: 79


Epoch 80/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5566, Training accuracy: 0.8023
Validation loss: 0.6526, Validation accuracy: 0.7918

Epoch %d: 80


Epoch 81/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5596, Training accuracy: 0.8004
Validation loss: 0.6541, Validation accuracy: 0.7838

Epoch %d: 81


Epoch 82/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5569, Training accuracy: 0.8032
Validation loss: 0.6536, Validation accuracy: 0.7874

Epoch %d: 82


Epoch 83/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5484, Training accuracy: 0.8045
Validation loss: 0.6601, Validation accuracy: 0.7836

Epoch %d: 83


Epoch 84/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5463, Training accuracy: 0.8063
Validation loss: 0.6560, Validation accuracy: 0.7860

Epoch %d: 84


Epoch 85/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5473, Training accuracy: 0.8032
Validation loss: 0.6467, Validation accuracy: 0.7914

Epoch %d: 85


Epoch 86/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5420, Training accuracy: 0.8062
Validation loss: 0.6515, Validation accuracy: 0.7904

Epoch %d: 86


Epoch 87/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5407, Training accuracy: 0.8072
Validation loss: 0.6637, Validation accuracy: 0.7878

Epoch %d: 87


Epoch 88/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5401, Training accuracy: 0.8090
Validation loss: 0.6555, Validation accuracy: 0.7898

Epoch %d: 88


Epoch 89/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5343, Training accuracy: 0.8078
Validation loss: 0.6481, Validation accuracy: 0.7900

Epoch %d: 89


Epoch 90/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5337, Training accuracy: 0.8099
Validation loss: 0.6542, Validation accuracy: 0.7880

Epoch %d: 90


Epoch 91/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5322, Training accuracy: 0.8099
Validation loss: 0.6676, Validation accuracy: 0.7850

Epoch %d: 91


Epoch 92/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5335, Training accuracy: 0.8084
Validation loss: 0.6685, Validation accuracy: 0.7860

Epoch %d: 92


Epoch 93/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5339, Training accuracy: 0.8093
Validation loss: 0.6464, Validation accuracy: 0.7898

Epoch %d: 93


Epoch 94/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5246, Training accuracy: 0.8120
Validation loss: 0.6432, Validation accuracy: 0.7922

Epoch %d: 94


Epoch 95/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5234, Training accuracy: 0.8121
Validation loss: 0.6395, Validation accuracy: 0.7898

Epoch %d: 95


Epoch 96/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5205, Training accuracy: 0.8147
Validation loss: 0.6393, Validation accuracy: 0.7924

Epoch %d: 96


Epoch 97/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5172, Training accuracy: 0.8156
Validation loss: 0.6473, Validation accuracy: 0.7916

Epoch %d: 97


Epoch 98/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5163, Training accuracy: 0.8151
Validation loss: 0.6420, Validation accuracy: 0.7954

Epoch %d: 98


Epoch 99/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5127, Training accuracy: 0.8175
Validation loss: 0.6365, Validation accuracy: 0.7966

Epoch %d: 99


Epoch 100/100:   0%|          | 0/88 [00:00<?, ?it/s]

Training loss: 0.5081, Training accuracy: 0.8166
Validation loss: 0.6281, Validation accuracy: 0.7966

==> Optimization finished! Best validation accuracy: 0.7966


In [None]:
# Save the weights `net` holds after the last fine-tuning epoch.
# This deliberately uses a different file name from model_checkpoint_net
# ('net_b512_lr1p0_e250_e100'): the training loop already writes the
# best-validation encoder to that path, and overwriting it with last-epoch
# weights would leave it out of sync with the saved classifier checkpoint.
torch.save(net.state_dict(), "net_b512_lr1p0_e250_e100_last")