<a href="https://colab.research.google.com/github/imthelizardking/cmp719-project/blob/main/cmp719_project_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Main text

Mount gdrive for saving weights etc.

In [1]:
from google.colab import drive
drive.mount('/content/drive/')
!cd '/content/drive/MyDrive'

Mounted at /content/drive/


Import required packages

In [2]:
!pip install torch torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F



ResNet-56 Model:

In [19]:
# Define the ResNet-56 model
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        # Adjust the number of input channels for the skip connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)  # Skip connection
        out = self.relu(out)
        return out

class ResNet56(nn.Module):
    def __init__(self, num_classes):
        super(ResNet56, self).__init__()
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(16, 9, stride=1)
        self.layer2 = self._make_layer(32, 9, stride=2)
        self.layer3 = self._make_layer(64, 9, stride=2)
        self.linear = nn.Linear(64, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(BasicBlock(self.in_planes, planes, stride))
            self.in_planes = planes * BasicBlock.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.ReLU()(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = nn.AdaptiveAvgPool2d(1)(out)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out

Set configuration for training ResNet-56 w/ cifar-100:

In [20]:
# Set device to GPU if available, otherwise use CPU
device = torch.device("cuda")

# Load CIFAR-100 dataset
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [21]:
# Create ResNet-56 model instance
model_resnet56 = ResNet56(num_classes=100).to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model_resnet56.parameters(), lr=0.001)
EPOCHS_RESNET56 = 300

Train ResNet-56 w/ cifar-100

In [22]:
# Training loop
for epoch in range(EPOCHS_RESNET56):  # Number of epochs
    model_resnet56.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model_resnet56(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if i % 100 == 99:  # Print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

    # Validation
    model_resnet56.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model_resnet56(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print('Accuracy on test set after epoch %d: %.2f %%' % (epoch + 1, accuracy))

print('Training finished.')

[1,   100] loss: 4.442
[1,   200] loss: 4.092
[1,   300] loss: 3.925
Accuracy on test set after epoch 1: 9.82 %
[2,   100] loss: 3.683
[2,   200] loss: 3.569
[2,   300] loss: 3.483
Accuracy on test set after epoch 2: 18.80 %
[3,   100] loss: 3.296
[3,   200] loss: 3.209
[3,   300] loss: 3.099
Accuracy on test set after epoch 3: 22.28 %
[4,   100] loss: 2.923
[4,   200] loss: 2.842
[4,   300] loss: 2.767
Accuracy on test set after epoch 4: 29.28 %
[5,   100] loss: 2.629
[5,   200] loss: 2.577
[5,   300] loss: 2.511
Accuracy on test set after epoch 5: 32.79 %
[6,   100] loss: 2.365
[6,   200] loss: 2.362
[6,   300] loss: 2.324
Accuracy on test set after epoch 6: 35.71 %
[7,   100] loss: 2.237
[7,   200] loss: 2.188
[7,   300] loss: 2.135
Accuracy on test set after epoch 7: 40.29 %
[8,   100] loss: 2.069
[8,   200] loss: 2.055
[8,   300] loss: 2.031
Accuracy on test set after epoch 8: 42.33 %
[9,   100] loss: 1.937
[9,   200] loss: 1.911
[9,   300] loss: 1.898
Accuracy on test set after e

Save trained ResNet-56 weights:

In [1]:
torch.save(model_resnet56.state_dict(), '/content/drive/MyDrive/719_project/trained_weights/resnet_model_weights.pth')

NameError: ignored

Calculate Top-1 accuracy for trained ResNet-56 model and cifar-100 dataset (in paper, 70.43%):

In [None]:
with torch.no_grad(): # for making sure there is no training, just inference
    model_resnet56.eval()  # model to eval. mode
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model_resnet56(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

top1_accuracy = 100 * correct / total
print('Top-1 Accuracy: {:.2f}%'.format(top1_accuracy))

Top-1 Accuracy: 41.61%


# **Vision Transformer w/ Feature Guidance:**

Vision Transformer w/ feature guidance:

In [None]:
class T2TViT(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, num_classes, embed_dim, depth, heads, mlp_dim, token_dim):
        super(T2TViT, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.num_patches = (image_size // patch_size) ** 2
        self.num_tokens = self.num_patches
        self.token_dim = token_dim

        self.patch_embeddings = nn.Conv2d(in_channels, self.token_dim, kernel_size=patch_size, stride=patch_size)
        self.token_embeddings = nn.Parameter(torch.randn(1, self.num_tokens, self.token_dim))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, heads, mlp_dim),
            depth
        )

        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Patch embeddings
        x = self.patch_embeddings(x)

        # Reshaping the patches
        x = x.flatten(2).transpose(1, 2)

        # Token embeddings
        tokens = self.token_embeddings.repeat(x.shape[0], 1, 1)

        # Concatenate token embeddings with patch embeddings
        x = torch.cat((tokens, x), dim=1)  # Concatenate along the second dimension

        # Transformer layers
        x = self.transformer(x)

        # Global pooling (mean)
        x = x.mean(dim=1)

        # Classification
        x = self.classifier(x)

        return x

In [None]:
# Define the Vision Transformer model
class VisionTransformer_fg(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim):
        super(VisionTransformer_fg, self).__init__()
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2  # Assuming RGB images

        self.patch_embedding = nn.Sequential(
            nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(start_dim=2)
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, dim_feedforward=mlp_dim),
            num_layers=depth
        )
        self.classifier = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.mean(dim=0)
        x = self.classifier(x)
        return x

Set training configuration for Vision Transformer w/ feature guidance

In [None]:
# Set up hyperparameters and data loaders
torch.cuda.empty_cache()
image_size = 32
patch_size = 16
num_classes = 100
dim = 128 #64
depth = 7 #10
heads = 4
token_dim = 128
mlp_dim = 256 #512
batch_size = 128
learning_rate = 1e-3
epochs = 50
BETA = 2.5 # scaler for feature guidance loss
device = torch.device("cuda")


# Create an instance of the model
#model_ViT = VisionTransformer_fg(image_size, patch_size, num_classes, dim, depth, heads, mlp_dim)
model_ViT = T2TViT(image_size, patch_size, 3, num_classes, dim, depth, heads, mlp_dim, token_dim)

model_ViT.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ViT.parameters(), lr=learning_rate, weight_decay=0.01)
# Create data loaders (replace with your own datasets)
train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)



Files already downloaded and verified


If pre-trained weights will be used, run following code snippet:

In [None]:
model_ViT.load_state_dict(torch.load('/content/drive/MyDrive/719_project/trained_weights/ViT_model_weights.pth'))

Vision Transformer Trainer:

In [None]:
# Training loop
#model_resnet56.eval() # no training for cnn, just eval.
model_ViT.train()
for epoch in range(epochs):
    train_loss = 0.0
    for images, labels in train_dataloader:
        images, labels = images.to(device), labels.to(device)
        # zero grads
        optimizer.zero_grad()
        # Forward pass
        outputs = model_ViT(images)
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Update weights
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_train_loss:.4f}")

Epoch 1/50 - Loss: 4.3495
Epoch 2/50 - Loss: 4.2698
Epoch 3/50 - Loss: 4.2106
Epoch 4/50 - Loss: 4.1531
Epoch 5/50 - Loss: 4.1188
Epoch 6/50 - Loss: 4.0963
Epoch 7/50 - Loss: 4.0849
Epoch 8/50 - Loss: 4.0796
Epoch 9/50 - Loss: 4.0713
Epoch 10/50 - Loss: 4.0649
Epoch 11/50 - Loss: 4.0593
Epoch 12/50 - Loss: 4.0578
Epoch 13/50 - Loss: 4.0503
Epoch 14/50 - Loss: 4.0532
Epoch 15/50 - Loss: 4.0450
Epoch 16/50 - Loss: 4.0472
Epoch 17/50 - Loss: 4.0429
Epoch 18/50 - Loss: 4.0441
Epoch 19/50 - Loss: 4.0366
Epoch 20/50 - Loss: 4.0358
Epoch 21/50 - Loss: 4.0352
Epoch 22/50 - Loss: 4.0337
Epoch 23/50 - Loss: 4.0305
Epoch 24/50 - Loss: 4.0273
Epoch 25/50 - Loss: 4.0277
Epoch 26/50 - Loss: 4.0224
Epoch 27/50 - Loss: 4.0266
Epoch 28/50 - Loss: 4.0211
Epoch 29/50 - Loss: 4.0210
Epoch 30/50 - Loss: 4.0200
Epoch 31/50 - Loss: 4.0184
Epoch 32/50 - Loss: 4.0156
Epoch 33/50 - Loss: 4.0162
Epoch 34/50 - Loss: 4.0186
Epoch 35/50 - Loss: 4.0146
Epoch 36/50 - Loss: 4.0120
Epoch 37/50 - Loss: 4.0116
Epoch 38/5

Calculate Top-1 Accuracy for Vision Transformer w/ feature guidance:

In [None]:
torch.save(model_ViT.state_dict(), '/content/drive/MyDrive/719_project/trained_weights/ViT_model_weights.pth')

In [None]:
with torch.no_grad(): # for making sure there is no training, just inference
    model_ViT.eval()  # model to eval. mode
    total, correct = 0, 0
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model_ViT(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
top1_accuracy = 100 * correct / total
print('Top-1 Accuracy: {:.2f}%'.format(top1_accuracy))

Top-1 Accuracy: 7.43%


# TEMP:

In [None]:
# Define the training loop
def train_ViT_with_fg(model_ViT, model_cnn, dataloader, criterion, optimizer, device, BETA):

    ###def get_features_hook(module, input, output):
      #### Store the intermediate features in a global variable
      ###global student_features
      ###student_features = output
    ###def get_teacher_features_hook(module, input, output):
        #### Store the intermediate features in a global variable
        ###global teacher_features
        ###teacher_features = output

    model_cnn.eval() # no training for cnn, just eval.
    model_ViT.train()
    total_loss = 0.0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        #images, labels = images.cuda(), labels.cuda() # add this line

        optimizer.zero_grad()
        outputs = model_ViT(images)

        #### hook cnn and ViT for intermediate feature extraction #
        ###criterion_fg = nn.MSELoss()
        ###model_ViT.register_forward_hook(get_features_hook)
        #### Register a forward hook to extract features from the teacher model
        ###model_cnn.register_forward_hook(get_teacher_features_hook)
        ####loss_fg = criterion_fg(student_features, teacher_features.detach())  # detach the teacher features to prevent backpropagation through the teacher
        ###loss_fg = 0
        # hook cnn and ViT for intermediate feature extraction #
        loss_cls = criterion(outputs, labels) # cross-entropy loss
        loss = loss_cls + BETA * 0 # loss_fg
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
    return total_loss / len(dataloader.dataset)