In [1]:
import torch
import torchvision
from torchinfo import summary
import matplotlib.pyplot as plt
from tqdm import tqdm

import sys
sys.path.insert(0, '../scripts') # add 'scrips' subfolder to sys path for easier import 
from Dataset import get_cifar10, get_cifar100, CIFAR100_Fine_labels, CIFAR10_labels
from Dataset import ImageNetDataset, ConvertToPlotableImage
from Models import get_Model
from Utils import evaluate

from models.ghostnet import ghostnet
from models.mobilenetv2 import MobileNet_v2_x0_5, MobileNet_v2_x1_0

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [8]:
model_name = 'MobileNetV2_x0_5'
model = get_Model(model_name,100) # 100 class for cifar100

#model = ghostnet()
#model = torchvision.models.mobilenet_v2(100)
#model = torchvision.models.shufflenet_v2_x1_0()
#model = MobileNet_v2_x1_0(10)

model.to(device)
summary(model, (1,3,32,32))

6 0.5


Layer (type:depth-idx)                        Output Shape              Param #
MobileNetV2                                   [1, 100]                  --
├─Sequential: 1-1                             [1, 16, 16, 16]           --
│    └─Conv2d: 2-1                            [1, 16, 16, 16]           432
│    └─BatchNorm2d: 2-2                       [1, 16, 16, 16]           32
│    └─ReLU: 2-3                              [1, 16, 16, 16]           --
├─ModuleList: 1-2                             --                        --
│    └─Sequential: 2-4                        [1, 8, 16, 16]            --
│    │    └─InvertedResidual: 3-1             [1, 8, 16, 16]            608
│    └─Sequential: 2-5                        [1, 12, 16, 16]           --
│    │    └─InvertedResidual: 3-2             [1, 12, 16, 16]           1,608
│    │    └─InvertedResidual: 3-3             [1, 12, 16, 16]           2,688
│    └─Sequential: 2-6                        [1, 16, 8, 8]             --
│    │    └─

In [7]:
print(model)

MobileNetV2(
  (conv1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (blocks): ModuleList(
    (0): Sequential(
      (0): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
          (6): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (1): Sequential(

In [5]:
# inspect the basic blocks

In [6]:
from models.resnet import BasicBlock
BB = BasicBlock(16,16)
BB.to(device)
summary(BB, (1,16,32,32))

Layer (type:depth-idx)                   Output Shape              Param #
BasicBlock                               [1, 16, 32, 32]           --
├─Conv2d: 1-1                            [1, 16, 32, 32]           2,304
├─BatchNorm2d: 1-2                       [1, 16, 32, 32]           32
├─ReLU: 1-3                              [1, 16, 32, 32]           --
├─Conv2d: 1-4                            [1, 16, 32, 32]           2,304
├─BatchNorm2d: 1-5                       [1, 16, 32, 32]           32
Total params: 4,672
Trainable params: 4,672
Non-trainable params: 0
Total mult-adds (M): 4.72
Input size (MB): 0.07
Forward/backward pass size (MB): 0.52
Params size (MB): 0.02
Estimated Total Size (MB): 0.61

In [7]:
#print(model)

In [8]:
dummy_in = torch.rand([1,3,32,32], device= device)
feats, logits = model(dummy_in, is_feat = True)

In [9]:
feats[0].shape

torch.Size([1, 16, 32, 32])

In [10]:
feats[1].shape

torch.Size([1, 64, 32, 32])

In [11]:
feats[2].shape

torch.Size([1, 128, 16, 16])

In [12]:
feats[3].shape

torch.Size([1, 256, 8, 8])

In [13]:
feats[4].shape # after average pooling of the 64x8x8 activation maps

torch.Size([1, 256])

In [14]:
logits.shape

torch.Size([1, 100])

## Load pre-train weight for model trained from-scratch on cifar100

In [15]:
#model.load_state_dict(torch.load('/home/chitraz/Documents/UoS_MSc/EEEM056_Project/Experiments/saves/ResNet-14_Exp1A_182epochs_Val:64.2_2023-05-24.pth'))
model.load_state_dict(torch.load('/home/chitraz/Documents/UoS_MSc/EEEM056_Project/Experiments/saves/WRN-40-4_pretrains_200epochs_Val:76.86_2023-06-23.pth'))

<All keys matched successfully>

## define functions to pre-compute and fetch the logits and also intermediate feature tensors

In [28]:
# for logits

def PreCompute_logits(model, device, Dataloader_train):
    MAP_ImgIdx_logits = {}

    with torch.no_grad():
        model.to(device)
        model.eval()
        for _ , batch in tqdm(enumerate(Dataloader_train), total = len(Dataloader_train)):    
            img_batch, _ , idxs = batch # unpack        
            img_batch = img_batch.to(device) # put images on gpu mem
            pred_logits = model(img_batch) # forward pass to get predicted logits 

            # save mapping as a Dict: image index -> output logits 
            for (idx, logits) in zip(idxs,pred_logits): 

                MAP_ImgIdx_logits[idx.item()] = logits

    return MAP_ImgIdx_logits

def fetch_logtis(teacher_map, idxs):
    batch_size = len(idxs)
    # get number of classes by checking size of logit vector
    num_class = len(list(teacher_map.values())[0])

    logits_batch = torch.zeros((batch_size, num_class), dtype=torch.float32)
    count = 0 

    for idx in idxs:
        logits_batch[count, :] = teacher_map[idx.item()]
        count += 1

    return logits_batch

In [29]:
# for intermediate feature tensors
def PreCompute_feats(model, device, Dataloader_train):

    MAP_ImgIdx_feats1 = {}
    MAP_ImgIdx_feats2 = {}
    MAP_ImgIdx_feats3 = {}
    
    with torch.no_grad():
        model.to(device)
        model.eval()
        for _ , batch in tqdm(enumerate(Dataloader_train), total = len(Dataloader_train)):    
            img_batch, _ , idxs = batch # unpack        
            img_batch = img_batch.to(device) # put images on gpu mem

            feats, _ = model(img_batch, is_feat=True, preact=True) # forward pass to get activations  

            # save mapping as a Dict: image index -> feature map [16x32x32]
            for (idx, feat) in zip(idxs,feats[1]): 
                MAP_ImgIdx_feats1[idx.item()] = feat
            
            # save mapping as a Dict: image index -> feature map [32x16x16]
            for (idx, feat) in zip(idxs,feats[2]): 
                MAP_ImgIdx_feats2[idx.item()] = feat
            
            # save mapping as a Dict: image index -> feature map [64x8x8]
            for (idx, feat) in zip(idxs,feats[3]): 
                MAP_ImgIdx_feats3[idx.item()] = feat
            
                
    return MAP_ImgIdx_feats1, MAP_ImgIdx_feats2, MAP_ImgIdx_feats3

# construct the [Batch_size x 16 x 32 x 32] tensor from relavent [16 x 32 x 32] tensors found in our pre-computed mapping 
def fetch_feats1(teacher_map, idxs):
    batch_size = len(idxs)
    feat_batch = torch.zeros((batch_size, 16,32,32), dtype=torch.float32)
    count = 0

    for idx in idxs:
        feat_batch[count,:,:,:] = teacher_map[idx.item()]
        count += 1

    return feat_batch

# construct the [Batch_size x 32 x 16 x 16] tensor from relavent [32 x 16 x 16] tensors found in our pre-computed mapping 
def fetch_feats2(teacher_map, idxs):
    batch_size = len(idxs)
    feat_batch = torch.zeros((batch_size, 32,16,16), dtype=torch.float32)
    count = 0

    for idx in idxs:
        feat_batch[count,:,:,:] = teacher_map[idx.item()]
        count += 1

    return feat_batch

# construct the [Batch_size x 64 x 8 x 8] tensor from relavent [64 x 8 x 8] tensors found in our pre-computed mapping 
def fetch_feats3(teacher_map, idxs):
    batch_size = len(idxs)
    feat_batch = torch.zeros((batch_size, 64,8,8), dtype=torch.float32)
    count = 0

    for idx in idxs:
        feat_batch[count,:,:,:] = teacher_map[idx.item()]
        count += 1

    return feat_batch

In [30]:
# to completely clear model from gpu mem 

#del model
#gc.collect()
#torch.cuda.empty_cache()

## create dataloader for cifar 10 train set

In [31]:
dataset_train = get_cifar10('train', Dataset_dir = '/home/chitraz/Documents/UoS_MSc/EEEM056_Project/Experiments/dataset')
dataloader_train  = torch.utils.data.DataLoader(dataset_train, batch_size = 128, shuffle = False, num_workers = 16)

## check the pre-computed features/logits 

In [32]:
#torch.manual_seed(0)
#torch.use_deterministic_algorithms(True)
#torch.cuda.manual_seed_all(42)


# precompute logits/features
logit_map = PreCompute_logits(model,device,dataloader_train)
F1_map,F2_map,F3_map = PreCompute_feats(model,device,dataloader_train)

model.to(device)
model.eval()
for itr, Batch in enumerate(dataloader_train):
    # unpack 
    Images, _, idxs = Batch
    Images = Images.to(device)
    
    # get pre computed logits 
    F1 = fetch_feats1(F1_map, idxs).to(device)
    F2 = fetch_feats2(F2_map, idxs).to(device)
    F3 = fetch_feats3(F3_map, idxs).to(device)
    lo = fetch_logtis(logit_map,idxs).to(device)
    
    # compute it via a forward pass (to verify it the same)
    feats, logits = model(Images, is_feat=True, preact=True) # this non-deter
    
    # debug
    print(idxs) 
    print(logits[0,:]) 
    print(lo[0,:]) 
    
    # compare the fetched pre-comuted logits/feauture to one done on the fly 
    print(torch.all(logits==lo))
    print(torch.all(F1==feats[1]))
    print(torch.all(F2==feats[2]))
    print(torch.all(F3==feats[3]))
    
    # DIFFERENT! WHY? Oh, it maybe due to the random augmentation on each train samples! 
    
    sys.exit()

100%|█████████████████████████████████████████████████████| 391/391 [00:01<00:00, 330.71it/s]
100%|█████████████████████████████████████████████████████| 391/391 [00:01<00:00, 323.78it/s]


tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
tensor([-7.8895e-01,  4.7946e-01,  3.1003e+00, -1.5134e+00,  6.0077e-01,
        -3.3058e+00,  7.0202e+00,  1.8596e+00,  3.1998e-01, -3.8095e+00,
         2.9673e+00,  4.0066e+00, -1.5725e+00, -4.1587e+00,  1.8725e+00,
         2.9815e+00, -1.2827e+00, -5.7274e+00,  2.2861e-01, 

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


## DIFFERENT RESULTS DUE TO THE RANDOM AUGMENATION OF THE TRAIN SAMPLES!!
## Pre-computed features are for a single augmentation setting.

In [None]:
# try the valid of test set!

In [33]:
dataset_test = get_cifar10('test', Dataset_dir = '/home/chitraz/Documents/UoS_MSc/EEEM056_Project/Experiments/dataset')
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size = 128, shuffle = False, num_workers = 16)

In [34]:
logit_map = PreCompute_logits(model,device,dataloader_test)
F1_map,F2_map,F3_map = PreCompute_feats(model,device,dataloader_test)

model.to(device)
model.eval()
for itr, Batch in enumerate(dataloader_test):
    # unpack 
    Images, _, idxs = Batch
    Images = Images.to(device)
    
    # get pre computed logits 
    F1 = fetch_feats1(F1_map, idxs).to(device)
    F2 = fetch_feats2(F2_map, idxs).to(device)
    F3 = fetch_feats3(F3_map, idxs).to(device)
    lo = fetch_logtis(logit_map,idxs).to(device)
      
    # compute it via a forward pass
    feats, logits = model(Images, is_feat=True, preact=False) 

    # debug
    print(idxs) 
    print(logits[0,:]) 
    print(lo[0,:]) 
    
    # check 
    print(torch.all(logits==lo))
    print(torch.all(F1==feats[1]))
    print(torch.all(F2==feats[2]))
    print(torch.all(F3==feats[3]))
    
    sys.exit()

100%|███████████████████████████████████████████████████████| 79/79 [00:00<00:00, 304.60it/s]
100%|███████████████████████████████████████████████████████| 79/79 [00:00<00:00, 206.13it/s]


tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
tensor([-3.9151,  3.2859,  5.1106,  3.5057,  3.3187, -2.8528,  3.6599,  4.3536,
         1.0994, -0.7819,  6.0692,  8.3409, -9.6483, -0.8250,  2.0864, -3.0817,
         6.1281, -6.6057, -1.3885, -0.9334, -5.8815, -1.1331,  5.5258, -5.6761,
        -4.4374, -0.5604,  0.2083, -2.6

SystemExit: 

In [37]:
feats[-2].shape

torch.Size([128, 64, 8, 8])

In [44]:
model.avgpool(x).shape

torch.Size([1, 64, 1, 1])

In [45]:
model.fc

Linear(in_features=64, out_features=100, bias=True)