In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

# Reaching this line means every import above resolved without error.
print("Setup complete")


Setup complete


In [4]:
# Shared data-pipeline settings. Keeping them in one place guarantees that the
# training and test pipelines are normalized identically and batched the same way.
# NOTE(review): these look like the commonly quoted CIFAR-100 per-channel
# mean / std values — confirm the source before reusing them elsewhere.
NORM_MEAN = (0.5071, 0.4865, 0.4409)
NORM_STD = (0.2673, 0.2564, 0.2762)
BATCH_SIZE = 128
NUM_WORKERS = 2

# Training pipeline: random augmentation first (on PIL images), then tensor
# conversion and per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # Add color jitter
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Test pipeline: no augmentation, only tensor conversion and the same normalization.
# It is applied to the `train=False` split below (there is no separate validation split).
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Download CIFAR-100 dataset (the second call reuses the files fetched by the first)
train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# Data loaders: training batches are reshuffled every epoch, test order is fixed
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print("Data loaded")


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [02:50<00:00, 989129.88it/s] 


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Data loaded


In [6]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [7]:
# Load pretrained ResNet101 and modify for CIFAR-100 classes (100)
# NOTE(review): recent torchvision releases prefer the `weights=` argument over
# `pretrained=` — confirm against the installed version before changing it.
model = torchvision.models.resnet101(pretrained=True)

# Change first conv layer to kernel=3, stride=1 to accommodate 32x32 CIFAR images.
# This is a newly constructed layer, so it does not reuse the loaded conv1 weights.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()  # Remove maxpool to keep feature map size

# Replace the classification head with a new linear layer that has 100 outputs.
model.fc = nn.Linear(model.fc.in_features, 100)  # CIFAR-100 classes

# Move the model to the target device before the optimizer is built from its parameters.
model = model.to(device)


# Loss, optimizer and learning-rate schedule used by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# T_max=100 equals `num_epochs` in the training loop, where scheduler.step() runs
# once per epoch; keep the two values in sync if either one changes.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)


print(f"Model ready on {device}")




Model ready on cuda


In [8]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over `loader` and return the average loss per sample.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable yielding `(images, labels)` batches.
        criterion: Callable mapping `(outputs, labels)` to a scalar loss tensor.
            The running total weights each batch loss by its batch size, which
            assumes the criterion returns a per-batch mean — TODO confirm if a
            different reduction is ever used.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Sum of batch losses weighted by batch size, divided by the number
        of samples actually processed. Returns 0.0 if the loader yields nothing.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
    # Normalize by what was actually iterated rather than len(loader.dataset):
    # the two differ when the loader drops the last batch or samples a subset.
    return running_loss / samples_seen if samples_seen else 0.0

def evaluate(model, loader, device):
    """Return the fraction of samples in `loader` whose top-scoring class matches the label.

    The model is put into eval mode and gradients are disabled for the whole pass.
    """
    model.eval()
    n_hits, n_seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            # Index of the highest score along the class dimension
            top1 = torch.max(model(batch_x), 1).indices
            n_seen += batch_y.size(0)
            n_hits += top1.eq(batch_y).sum().item()
    return n_hits / n_seen

# Number of full passes over the training set. The cosine schedule was created
# with T_max=100, so keep the two values in sync if this is changed.
num_epochs = 100
for epoch in range(num_epochs):
    # One optimization pass over train_loader; returns the average training loss.
    loss = train_epoch(model, train_loader, criterion, optimizer, device)
    # Accuracy on test_loader is measured after every epoch; the value is only printed.
    accuracy = evaluate(model, test_loader, device)
    # Advance the learning-rate schedule once per epoch, after that epoch's optimizer steps.
    scheduler.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}, Test Accuracy: {accuracy*100:.2f}%")


100%|██████████| 391/391 [03:03<00:00,  2.13it/s]


Epoch [1/100], Loss: 1.9476, Test Accuracy: 64.82%


100%|██████████| 391/391 [03:04<00:00,  2.12it/s]


Epoch [2/100], Loss: 0.9626, Test Accuracy: 70.81%


100%|██████████| 391/391 [03:05<00:00,  2.11it/s]


Epoch [3/100], Loss: 0.7303, Test Accuracy: 72.56%


100%|██████████| 391/391 [03:05<00:00,  2.11it/s]


Epoch [4/100], Loss: 0.6011, Test Accuracy: 75.92%


100%|██████████| 391/391 [03:05<00:00,  2.11it/s]


Epoch [5/100], Loss: 0.5039, Test Accuracy: 74.61%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [6/100], Loss: 0.4351, Test Accuracy: 76.68%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [7/100], Loss: 0.3904, Test Accuracy: 76.64%


100%|██████████| 391/391 [03:05<00:00,  2.11it/s]


Epoch [8/100], Loss: 0.3400, Test Accuracy: 76.44%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [9/100], Loss: 0.3159, Test Accuracy: 77.55%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [10/100], Loss: 0.2776, Test Accuracy: 77.08%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [11/100], Loss: 0.2563, Test Accuracy: 76.80%


100%|██████████| 391/391 [03:05<00:00,  2.11it/s]


Epoch [12/100], Loss: 0.2374, Test Accuracy: 77.82%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [13/100], Loss: 0.2304, Test Accuracy: 78.71%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [14/100], Loss: 0.2103, Test Accuracy: 77.98%


100%|██████████| 391/391 [03:05<00:00,  2.11it/s]


Epoch [15/100], Loss: 0.2020, Test Accuracy: 77.52%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [16/100], Loss: 0.1946, Test Accuracy: 78.28%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [17/100], Loss: 0.1648, Test Accuracy: 78.58%


100%|██████████| 391/391 [03:05<00:00,  2.11it/s]


Epoch [18/100], Loss: 0.1572, Test Accuracy: 77.64%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [19/100], Loss: 0.1531, Test Accuracy: 79.25%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [20/100], Loss: 0.1414, Test Accuracy: 78.37%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [21/100], Loss: 0.1462, Test Accuracy: 78.71%


100%|██████████| 391/391 [03:05<00:00,  2.11it/s]


Epoch [22/100], Loss: 0.1326, Test Accuracy: 79.22%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [23/100], Loss: 0.1146, Test Accuracy: 79.66%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [24/100], Loss: 0.1163, Test Accuracy: 78.57%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [25/100], Loss: 0.1164, Test Accuracy: 78.71%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [26/100], Loss: 0.1034, Test Accuracy: 79.61%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [27/100], Loss: 0.0983, Test Accuracy: 79.91%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [28/100], Loss: 0.0992, Test Accuracy: 79.39%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [29/100], Loss: 0.0924, Test Accuracy: 79.48%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [30/100], Loss: 0.0856, Test Accuracy: 81.24%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [31/100], Loss: 0.0710, Test Accuracy: 80.48%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [32/100], Loss: 0.0785, Test Accuracy: 80.05%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [33/100], Loss: 0.0815, Test Accuracy: 80.47%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [34/100], Loss: 0.0662, Test Accuracy: 80.82%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [35/100], Loss: 0.0598, Test Accuracy: 81.18%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [36/100], Loss: 0.0594, Test Accuracy: 80.78%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [37/100], Loss: 0.0608, Test Accuracy: 79.98%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [38/100], Loss: 0.0518, Test Accuracy: 80.96%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [39/100], Loss: 0.0510, Test Accuracy: 80.96%


100%|██████████| 391/391 [03:05<00:00,  2.10it/s]


Epoch [40/100], Loss: 0.0434, Test Accuracy: 81.92%


100%|██████████| 391/391 [03:07<00:00,  2.08it/s]


Epoch [41/100], Loss: 0.0377, Test Accuracy: 81.61%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [42/100], Loss: 0.0359, Test Accuracy: 81.55%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [43/100], Loss: 0.0323, Test Accuracy: 81.53%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [44/100], Loss: 0.0302, Test Accuracy: 81.43%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [45/100], Loss: 0.0318, Test Accuracy: 82.08%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [46/100], Loss: 0.0292, Test Accuracy: 82.43%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [47/100], Loss: 0.0243, Test Accuracy: 81.65%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [48/100], Loss: 0.0230, Test Accuracy: 82.70%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [49/100], Loss: 0.0197, Test Accuracy: 82.50%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [50/100], Loss: 0.0175, Test Accuracy: 83.04%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [51/100], Loss: 0.0138, Test Accuracy: 83.46%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [52/100], Loss: 0.0170, Test Accuracy: 82.76%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [53/100], Loss: 0.0134, Test Accuracy: 83.49%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [54/100], Loss: 0.0103, Test Accuracy: 83.92%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [55/100], Loss: 0.0110, Test Accuracy: 83.51%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [56/100], Loss: 0.0092, Test Accuracy: 83.86%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [57/100], Loss: 0.0084, Test Accuracy: 84.15%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [58/100], Loss: 0.0063, Test Accuracy: 84.10%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [59/100], Loss: 0.0056, Test Accuracy: 84.25%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [60/100], Loss: 0.0046, Test Accuracy: 84.45%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [61/100], Loss: 0.0050, Test Accuracy: 84.26%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [62/100], Loss: 0.0045, Test Accuracy: 84.27%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [63/100], Loss: 0.0042, Test Accuracy: 84.81%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [64/100], Loss: 0.0048, Test Accuracy: 84.91%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [65/100], Loss: 0.0044, Test Accuracy: 84.98%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [66/100], Loss: 0.0043, Test Accuracy: 84.80%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [67/100], Loss: 0.0037, Test Accuracy: 84.96%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [68/100], Loss: 0.0037, Test Accuracy: 85.08%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [69/100], Loss: 0.0040, Test Accuracy: 84.96%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [70/100], Loss: 0.0035, Test Accuracy: 85.13%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [71/100], Loss: 0.0035, Test Accuracy: 85.19%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [72/100], Loss: 0.0036, Test Accuracy: 85.17%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [73/100], Loss: 0.0037, Test Accuracy: 85.15%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [74/100], Loss: 0.0034, Test Accuracy: 85.20%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [75/100], Loss: 0.0034, Test Accuracy: 85.20%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [76/100], Loss: 0.0034, Test Accuracy: 85.23%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [77/100], Loss: 0.0035, Test Accuracy: 85.38%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [78/100], Loss: 0.0035, Test Accuracy: 85.31%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [79/100], Loss: 0.0034, Test Accuracy: 85.45%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [80/100], Loss: 0.0035, Test Accuracy: 85.37%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [81/100], Loss: 0.0035, Test Accuracy: 85.15%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [82/100], Loss: 0.0033, Test Accuracy: 85.44%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [83/100], Loss: 0.0032, Test Accuracy: 85.41%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [84/100], Loss: 0.0034, Test Accuracy: 85.27%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [85/100], Loss: 0.0032, Test Accuracy: 85.42%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [86/100], Loss: 0.0032, Test Accuracy: 85.50%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [87/100], Loss: 0.0032, Test Accuracy: 85.58%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [88/100], Loss: 0.0032, Test Accuracy: 85.42%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [89/100], Loss: 0.0033, Test Accuracy: 85.42%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [90/100], Loss: 0.0032, Test Accuracy: 85.47%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [91/100], Loss: 0.0032, Test Accuracy: 85.36%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [92/100], Loss: 0.0033, Test Accuracy: 85.46%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [93/100], Loss: 0.0032, Test Accuracy: 85.31%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [94/100], Loss: 0.0031, Test Accuracy: 85.50%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [95/100], Loss: 0.0032, Test Accuracy: 85.42%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [96/100], Loss: 0.0033, Test Accuracy: 85.54%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [97/100], Loss: 0.0034, Test Accuracy: 85.59%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [98/100], Loss: 0.0032, Test Accuracy: 85.50%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [99/100], Loss: 0.0032, Test Accuracy: 85.58%


100%|██████████| 391/391 [03:06<00:00,  2.10it/s]


Epoch [100/100], Loss: 0.0032, Test Accuracy: 85.44%


In [10]:
# Define the path where you want to save the weights
SAVE_PATH = "savepoints/teacher_resnet101.pth"

# torch.save does not create missing parent folders, so make sure the target
# directory exists first (skipped when SAVE_PATH is a bare filename).
save_dir = os.path.dirname(SAVE_PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Save the model's state_dict
# This saves only the learned parameters, not the model's architecture
torch.save(model.state_dict(), SAVE_PATH)

print(f"Model weights successfully saved to {SAVE_PATH}")

Model weights successfully saved to savepoints/teacher_resnet101.pth
