In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
from torch import device
from tqdm.notebook import tqdm
torch.manual_seed(47)


<torch._C.Generator at 0x7eabe7f04330>

In [None]:
class ResUnit(nn.Module):
    def __init__(self, p , stride,exp=1):
        super(ResUnit, self).__init__()
        self.c1 = nn.Conv2d(p, p*exp, kernel_size=3, stride=stride, padding=1, bias=False)
        self.b1 = nn.BatchNorm2d(p*exp)
        self.relu = nn.ReLU()
        self.c2 = nn.Conv2d( p*exp,  p*exp , kernel_size=3, stride=1, padding=1, bias=False)
        self.b2 = nn.BatchNorm2d( p*exp)
        self.drp = nn.Dropout(0.2)
        self.relu2 = nn.ReLU()

        ## to ensure same dimension
        self.residual = nn.Sequential()
        if stride != 1 or exp!=1:
            self.residual = nn.Sequential(
                nn.Conv2d( p, p*exp, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d( p*exp )
            )

    def forward(self, x):
        out = self.relu(self.b1(self.c1(x)))
        out = self.b2(self.c2(out))
        out = self.drp(out)
        res = self.residual(x)
        # print("x ",x.shape)
        # print("out ",out.shape)
        # print("res ",res.shape)
        out = self.relu2(out+res)
        return out

```
ResNet(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (conv2_x): Sequential(
    (0): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
  )
  (conv3_x): Sequential(
    (0): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
  )
  (conv4_x): Sequential(
    (0): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
  )
  (conv5_x): Sequential(
    (0): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
  )
  (avg_pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=100, bias=True)
)


In [None]:
class ResNet18(nn.Module):
    def __init__(self,p,expansion,num_classes=100):
        super(ResNet18, self).__init__()
        #input size is 32x32x3
        #first layer
        # to go from 32*32 to
        p=64
        l=[
            nn.Conv2d(3, p, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(p),
            nn.ReLU(inplace=True),
          ]
        l.append(ResUnit(p,  1))
        l.append(ResUnit(p,  1))

        l.append(ResUnit(p,  2,expansion))
        p*=expansion
        l.append(ResUnit(p,  1))

        l.append(ResUnit(p,  2,expansion))
        p*=expansion
        l.append(ResUnit(p,  1))

        l.append(ResUnit(p,  2,expansion))
        p*=expansion
        l.append(ResUnit(p,  1))


        # self.r3 = nn.Sequential(
        l.extend([
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(p,num_classes),
        ])
        # )
        self.model = nn.Sequential(*l)
    def forward(self, x):
        x = self.model(x)
        return x

In [None]:
transf = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

In [None]:
data = torchvision.datasets.CIFAR100(
    root="./data", train=True, download=True, transform=transf
)

100%|██████████| 169M/169M [00:06<00:00, 25.1MB/s]


In [None]:
#hyperparameters
p=64
lr =0.01
batch_size = 32
epochs = 200
exp = 2
weight_decay = 5e-4
crit = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
resnet = ResNet18(p,2,100).to(device)
opti = torch.optim.SGD(resnet.parameters(), lr=lr, weight_decay=1e-6,momentum = 0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opti, T_max=epochs)


In [None]:
train_size = int(0.9* len(data))
test_size = int(0.1*len(data))
train_data, val_data = torch.utils.data.random_split(data, [train_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
print(resnet)

ResNet18(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ResUnit(
      (c1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (b1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (c2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (b2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drp): Dropout(p=0.2, inplace=False)
      (relu2): ReLU()
      (residual): Sequential()
    )
    (4): ResUnit(
      (c1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (b1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (c2): Conv2d(64, 64, kernel_size

In [None]:
losses = []
accuracies = []
val_losses = []
val_accuracies = []

for epoch in range(epochs):
    i=0
    correct, total = 0, 0
    for inputs, labels in tqdm(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = resnet(inputs)

        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

        loss = crit(outputs, labels)

        opti.zero_grad()

        loss.backward()
        opti.step()
        scheduler.step()
        losses.append(loss.item())

        i+=1
    accuracies.append(float(correct)/float(total))
    print(f"Epoch {epoch+1}/{epochs}, Loss: {np.mean(losses[-i:])}, Accuracy: {accuracies[-1]}")


    resnet.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for val_in,val_l in val_dataloader:
            val_in = val_in.to(device)
            val_l = val_l.to(device)

            val_out = resnet(val_in)

            val_loss = crit(val_out, val_l)

            val_correct += (val_out.argmax(dim=1) == val_l).sum().item()
            val_total += val_l.size(0)
    val_losses.append(val_loss.item())
    val_accuracies.append(float(val_correct)/float(val_total))
    print(f"Validation Loss: {val_loss.item()}, Validation Accuracy: {val_accuracies[-1]}")
    resnet.train()


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 1/200, Loss: 3.843861480596308, Accuracy: 0.11095555555555556
Validation Loss: 3.9278852939605713, Validation Accuracy: 0.2038


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 2/200, Loss: 3.1683590601925826, Accuracy: 0.22
Validation Loss: 4.039178371429443, Validation Accuracy: 0.2224


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 3/200, Loss: 2.70920194076492, Accuracy: 0.3050888888888889
Validation Loss: 3.5823283195495605, Validation Accuracy: 0.3988


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 4/200, Loss: 2.3524857917913073, Accuracy: 0.3801333333333333
Validation Loss: 3.8319931030273438, Validation Accuracy: 0.341


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 5/200, Loss: 2.0564326296991378, Accuracy: 0.44324444444444444
Validation Loss: 3.7247538566589355, Validation Accuracy: 0.4932


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 6/200, Loss: 1.8407922047796026, Accuracy: 0.4925777777777778
Validation Loss: 4.419766902923584, Validation Accuracy: 0.4178


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 7/200, Loss: 1.630765006274472, Accuracy: 0.5444
Validation Loss: 3.3765697479248047, Validation Accuracy: 0.5368


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 8/200, Loss: 1.4757409768063885, Accuracy: 0.5813111111111111
Validation Loss: 4.280612468719482, Validation Accuracy: 0.4752


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 9/200, Loss: 1.2906547738206073, Accuracy: 0.6311333333333333
Validation Loss: 3.8930578231811523, Validation Accuracy: 0.5656


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 10/200, Loss: 1.1779774583521343, Accuracy: 0.6573555555555556
Validation Loss: 4.446514129638672, Validation Accuracy: 0.5006


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 11/200, Loss: 1.0048569947502337, Accuracy: 0.7042444444444445
Validation Loss: 3.2686607837677, Validation Accuracy: 0.5794


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 12/200, Loss: 0.9164691980247728, Accuracy: 0.7272666666666666
Validation Loss: 3.5806102752685547, Validation Accuracy: 0.5346


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 13/200, Loss: 0.7543816464796249, Accuracy: 0.7752666666666667
Validation Loss: 4.990780353546143, Validation Accuracy: 0.5776


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 14/200, Loss: 0.6664821376763258, Accuracy: 0.7989555555555555
Validation Loss: 4.84139347076416, Validation Accuracy: 0.558


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 15/200, Loss: 0.5409839905549553, Accuracy: 0.8354
Validation Loss: 4.901212692260742, Validation Accuracy: 0.577


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 16/200, Loss: 0.47662708922616964, Accuracy: 0.8545111111111111
Validation Loss: 4.455333232879639, Validation Accuracy: 0.551


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 17/200, Loss: 0.37309482306580893, Accuracy: 0.8853111111111112
Validation Loss: 4.329877853393555, Validation Accuracy: 0.5716


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 18/200, Loss: 0.3282747992023815, Accuracy: 0.8991111111111111
Validation Loss: 4.706404209136963, Validation Accuracy: 0.5676


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 19/200, Loss: 0.26792152946141057, Accuracy: 0.9178444444444445
Validation Loss: 4.882611274719238, Validation Accuracy: 0.563


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 20/200, Loss: 0.2369895437765435, Accuracy: 0.9275333333333333
Validation Loss: 5.310894012451172, Validation Accuracy: 0.5756


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 21/200, Loss: 0.18862324175719428, Accuracy: 0.9428222222222222
Validation Loss: 6.082357406616211, Validation Accuracy: 0.5672


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 22/200, Loss: 0.15666994282760008, Accuracy: 0.9538
Validation Loss: 5.826814651489258, Validation Accuracy: 0.5738


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 23/200, Loss: 0.12516309243614315, Accuracy: 0.9637555555555556
Validation Loss: 6.787991523742676, Validation Accuracy: 0.564


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 24/200, Loss: 0.1225474847536256, Accuracy: 0.9628
Validation Loss: 5.250673770904541, Validation Accuracy: 0.5714


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 25/200, Loss: 0.09789028731336938, Accuracy: 0.9709777777777778
Validation Loss: 6.501191139221191, Validation Accuracy: 0.573


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 26/200, Loss: 0.07867494731275702, Accuracy: 0.9783555555555555
Validation Loss: 6.011368751525879, Validation Accuracy: 0.58


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 27/200, Loss: 0.06380943599946044, Accuracy: 0.9832666666666666
Validation Loss: 5.980834007263184, Validation Accuracy: 0.5768


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 28/200, Loss: 0.07109413729873357, Accuracy: 0.9808888888888889
Validation Loss: 6.741372585296631, Validation Accuracy: 0.582


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 29/200, Loss: 0.05520892396632796, Accuracy: 0.9854444444444445
Validation Loss: 6.60136079788208, Validation Accuracy: 0.5724


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 30/200, Loss: 0.05680226345742472, Accuracy: 0.9844
Validation Loss: 6.686694622039795, Validation Accuracy: 0.5896


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 31/200, Loss: 0.04769997975068179, Accuracy: 0.9876888888888888
Validation Loss: 7.165894508361816, Validation Accuracy: 0.5712


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 32/200, Loss: 0.05510876612312822, Accuracy: 0.9841111111111112
Validation Loss: 7.309564590454102, Validation Accuracy: 0.5806


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 33/200, Loss: 0.0405986534725966, Accuracy: 0.9893777777777778
Validation Loss: 8.034729957580566, Validation Accuracy: 0.5684


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 34/200, Loss: 0.0356706113239463, Accuracy: 0.9909111111111111
Validation Loss: 6.81319522857666, Validation Accuracy: 0.581


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 35/200, Loss: 0.0363796283276659, Accuracy: 0.9896666666666667
Validation Loss: 6.1016459465026855, Validation Accuracy: 0.5808


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 36/200, Loss: 0.03634928075248189, Accuracy: 0.9897555555555556
Validation Loss: 6.505979537963867, Validation Accuracy: 0.5928


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 37/200, Loss: 0.031957287228747, Accuracy: 0.9917555555555555
Validation Loss: 7.154705047607422, Validation Accuracy: 0.5744


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 38/200, Loss: 0.03117592782439107, Accuracy: 0.9920444444444444
Validation Loss: 6.726372718811035, Validation Accuracy: 0.5884


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 39/200, Loss: 0.027923995014442253, Accuracy: 0.9927777777777778
Validation Loss: 7.0176005363464355, Validation Accuracy: 0.5778


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 40/200, Loss: 0.024455023819721994, Accuracy: 0.9937777777777778
Validation Loss: 6.728193283081055, Validation Accuracy: 0.582


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 41/200, Loss: 0.021391771718288436, Accuracy: 0.9952444444444445
Validation Loss: 5.875554084777832, Validation Accuracy: 0.5848


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 42/200, Loss: 0.020218688673614482, Accuracy: 0.9953555555555555
Validation Loss: 5.365443229675293, Validation Accuracy: 0.5876


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 43/200, Loss: 0.01766839604789203, Accuracy: 0.9962
Validation Loss: 6.674277305603027, Validation Accuracy: 0.588


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 44/200, Loss: 0.016409777104782187, Accuracy: 0.9958888888888889
Validation Loss: 6.998018264770508, Validation Accuracy: 0.5904


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 45/200, Loss: 0.01944261593263025, Accuracy: 0.9951777777777778
Validation Loss: 7.811710357666016, Validation Accuracy: 0.578


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 46/200, Loss: 0.01713731440540662, Accuracy: 0.9959111111111111
Validation Loss: 7.317972183227539, Validation Accuracy: 0.5818


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 47/200, Loss: 0.015494327728652946, Accuracy: 0.9963111111111111
Validation Loss: 6.8448381423950195, Validation Accuracy: 0.5852


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 48/200, Loss: 0.014501025780602528, Accuracy: 0.9970444444444444
Validation Loss: 6.707420349121094, Validation Accuracy: 0.587


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 49/200, Loss: 0.011669206931052218, Accuracy: 0.9976222222222222
Validation Loss: 6.568178176879883, Validation Accuracy: 0.5946


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 50/200, Loss: 0.009411406790935776, Accuracy: 0.9980888888888889
Validation Loss: 7.030298233032227, Validation Accuracy: 0.5992


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 51/200, Loss: 0.011180224721128843, Accuracy: 0.9972888888888889
Validation Loss: 6.924497127532959, Validation Accuracy: 0.5932


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 52/200, Loss: 0.008538648854744538, Accuracy: 0.9984
Validation Loss: 7.154977321624756, Validation Accuracy: 0.5908


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 53/200, Loss: 0.014087622178279594, Accuracy: 0.9968444444444444
Validation Loss: 6.7415056228637695, Validation Accuracy: 0.5934


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 54/200, Loss: 0.01090384609468977, Accuracy: 0.9977111111111111
Validation Loss: 6.221029758453369, Validation Accuracy: 0.5892


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 55/200, Loss: 0.01176893060849081, Accuracy: 0.9974888888888889
Validation Loss: 6.72760009765625, Validation Accuracy: 0.5928


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 56/200, Loss: 0.00894631959180455, Accuracy: 0.9980888888888889
Validation Loss: 6.703739166259766, Validation Accuracy: 0.5918


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 57/200, Loss: 0.009406367476613213, Accuracy: 0.9980444444444444
Validation Loss: 6.28927755355835, Validation Accuracy: 0.5846


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 58/200, Loss: 0.009311679779908865, Accuracy: 0.9980444444444444
Validation Loss: 7.2440714836120605, Validation Accuracy: 0.5902


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 59/200, Loss: 0.008614251648052133, Accuracy: 0.9981111111111111
Validation Loss: 6.712249755859375, Validation Accuracy: 0.5898


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 60/200, Loss: 0.0068779161572506, Accuracy: 0.9987555555555555
Validation Loss: 7.255383491516113, Validation Accuracy: 0.5934


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 61/200, Loss: 0.006616876340886894, Accuracy: 0.9988
Validation Loss: 7.51605224609375, Validation Accuracy: 0.5916


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 62/200, Loss: 0.006829710653767932, Accuracy: 0.9986
Validation Loss: 6.694307804107666, Validation Accuracy: 0.594


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 63/200, Loss: 0.00994115202847907, Accuracy: 0.9979777777777777
Validation Loss: 7.377459526062012, Validation Accuracy: 0.5906


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 64/200, Loss: 0.007181655073098204, Accuracy: 0.9986444444444444
Validation Loss: 7.066564559936523, Validation Accuracy: 0.5886


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 65/200, Loss: 0.006731971404885055, Accuracy: 0.9988222222222222
Validation Loss: 6.394832611083984, Validation Accuracy: 0.5938


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 66/200, Loss: 0.0068104320895738465, Accuracy: 0.9987333333333334
Validation Loss: 6.704599380493164, Validation Accuracy: 0.5924


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 67/200, Loss: 0.00673544520659542, Accuracy: 0.9984222222222222
Validation Loss: 6.700145244598389, Validation Accuracy: 0.5908


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 68/200, Loss: 0.006446909113283972, Accuracy: 0.9988
Validation Loss: 6.110090732574463, Validation Accuracy: 0.5884


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 69/200, Loss: 0.00582188269300685, Accuracy: 0.9987555555555555
Validation Loss: 5.576595306396484, Validation Accuracy: 0.5944


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 70/200, Loss: 0.005245930591314396, Accuracy: 0.9989777777777777
Validation Loss: 7.075263023376465, Validation Accuracy: 0.5942


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 71/200, Loss: 0.004882195711012333, Accuracy: 0.9990666666666667
Validation Loss: 6.933476448059082, Validation Accuracy: 0.596


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 72/200, Loss: 0.0050696552402908135, Accuracy: 0.9991555555555556
Validation Loss: 7.047314643859863, Validation Accuracy: 0.5922


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 73/200, Loss: 0.004618651452254208, Accuracy: 0.9991333333333333
Validation Loss: 6.676936626434326, Validation Accuracy: 0.5906


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 74/200, Loss: 0.0046160808358113176, Accuracy: 0.999
Validation Loss: 6.898190975189209, Validation Accuracy: 0.5974


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 75/200, Loss: 0.004406619497605521, Accuracy: 0.9991111111111111
Validation Loss: 6.944563865661621, Validation Accuracy: 0.5962


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 76/200, Loss: 0.004000832711019893, Accuracy: 0.9992888888888889
Validation Loss: 7.243062973022461, Validation Accuracy: 0.5932


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 77/200, Loss: 0.004137080671791038, Accuracy: 0.9990666666666667
Validation Loss: 7.151128768920898, Validation Accuracy: 0.5942


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 78/200, Loss: 0.005082160622554079, Accuracy: 0.9990222222222223
Validation Loss: 7.47381591796875, Validation Accuracy: 0.5898


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 79/200, Loss: 0.003950253657259606, Accuracy: 0.9992888888888889
Validation Loss: 6.683444976806641, Validation Accuracy: 0.5942


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 80/200, Loss: 0.004143720493025708, Accuracy: 0.9990888888888889
Validation Loss: 6.543318748474121, Validation Accuracy: 0.592


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 81/200, Loss: 0.0035766989486277112, Accuracy: 0.9993111111111111
Validation Loss: 6.5677266120910645, Validation Accuracy: 0.5998


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 82/200, Loss: 0.002751624887482355, Accuracy: 0.9995555555555555
Validation Loss: 6.518649101257324, Validation Accuracy: 0.5988


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 83/200, Loss: 0.0031073986993924167, Accuracy: 0.9994666666666666
Validation Loss: 6.687126159667969, Validation Accuracy: 0.6034


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 84/200, Loss: 0.0028718689010597412, Accuracy: 0.9995111111111111
Validation Loss: 6.563406467437744, Validation Accuracy: 0.5974


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 85/200, Loss: 0.0030150585754673527, Accuracy: 0.9994
Validation Loss: 6.639641761779785, Validation Accuracy: 0.595


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 86/200, Loss: 0.004551422550823407, Accuracy: 0.9990888888888889
Validation Loss: 6.94103479385376, Validation Accuracy: 0.5994


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 87/200, Loss: 0.0030794083342909014, Accuracy: 0.9994
Validation Loss: 6.302033424377441, Validation Accuracy: 0.5938


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 88/200, Loss: 0.004114185552942724, Accuracy: 0.9991555555555556
Validation Loss: 6.574766159057617, Validation Accuracy: 0.5952


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 89/200, Loss: 0.003126283908994809, Accuracy: 0.9995333333333334
Validation Loss: 6.517897605895996, Validation Accuracy: 0.5946


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 90/200, Loss: 0.003659241569084544, Accuracy: 0.9992444444444445
Validation Loss: 7.1240997314453125, Validation Accuracy: 0.5984


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 91/200, Loss: 0.002590087216084694, Accuracy: 0.9995333333333334
Validation Loss: 7.409906387329102, Validation Accuracy: 0.6012


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 92/200, Loss: 0.003128439687482258, Accuracy: 0.9993111111111111
Validation Loss: 7.1176910400390625, Validation Accuracy: 0.5942


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 93/200, Loss: 0.0031871869740196833, Accuracy: 0.9993777777777778
Validation Loss: 7.0605902671813965, Validation Accuracy: 0.6018


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 94/200, Loss: 0.002651403086109705, Accuracy: 0.9994222222222222
Validation Loss: 7.503003120422363, Validation Accuracy: 0.5992


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 95/200, Loss: 0.003070148375215061, Accuracy: 0.9993333333333333
Validation Loss: 7.081018447875977, Validation Accuracy: 0.5962


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 96/200, Loss: 0.0031966330298574784, Accuracy: 0.9993111111111111
Validation Loss: 7.448611736297607, Validation Accuracy: 0.5976


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 97/200, Loss: 0.003636077401721711, Accuracy: 0.9992666666666666
Validation Loss: 7.7927141189575195, Validation Accuracy: 0.5938


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 98/200, Loss: 0.0031822385610694444, Accuracy: 0.9994222222222222
Validation Loss: 7.086183547973633, Validation Accuracy: 0.5934


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 99/200, Loss: 0.0028104410089847204, Accuracy: 0.9994888888888889
Validation Loss: 6.792121887207031, Validation Accuracy: 0.5918


  0%|          | 0/1407 [00:00<?, ?it/s]

Epoch 100/200, Loss: 0.0034379259776809615, Accuracy: 0.9992444444444445
Validation Loss: 7.4932966232299805, Validation Accuracy: 0.5912


  0%|          | 0/1407 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
#save the model
torch.save(resnet.state_dict(), "resnet_100_epoch.pth")

# Neural Collapse

## NC 1

## NC 2

## NC 3


## NC 4

## NC 5