In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
from tqdm import tqdm

In [None]:
# Build AlexNet initialised with the torchvision IMAGENET1K_V1 pretrained weights.
alexnet = models.alexnet(
    weights=models.AlexNet_Weights.IMAGENET1K_V1,
)

In [None]:
DATASET = "imagenet"  # selects the ./data/{DATASET}-split directory used below

In [None]:
# Display the module tree; useful for finding layer indices such as classifier[6].
alexnet

In [None]:
# Freeze every parameter of the network, then re-enable gradients for all
# parameters inside the `classifier` submodule -- the whole head, not only its
# last layer -- so fine-tuning only touches the head.
alexnet.requires_grad_(False)
alexnet.classifier.requires_grad_(True)

In [None]:
# Preprocessing pipeline: resize to the fixed 224x224 input size, convert to a
# tensor, then standardise with the ImageNet-1k (pre-training data) channel
# mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Load the ImageNet-1k validation split from the local kagglehub cache.
# NOTE(review): machine-specific absolute root; consider a configurable DATA_DIR.
# This object is replaced by the saved fine-tuning split loaded further down.
train_data = ImageNet(
    root='/root/.cache/kagglehub/datasets/titericz/imagenet1k-val/versions/1',
    split="val",
    transform=transform,
)

In [None]:
# split data in half - half for alexnet, half for strong model
# from sklearn.model_selection import train_test_split
# import numpy as np
# from torch.utils.data import Subset

In [None]:
# train_indices, test_indices = train_test_split(np.arange(len(train_data)), test_size=.2, shuffle=True, stratify=train_data.targets)

In [None]:
# Load the saved fine-tuning split (presumably the pickled Subset written by the
# commented-out split cells). Unpickling a Dataset object needs
# weights_only=False: since torch 2.6, torch.load defaults to weights_only=True,
# which rejects arbitrary classes. Only do this for files you created yourself.
train_data = torch.load(f'./data/{DATASET}-split/finetune.pth', weights_only=False)
# wtsg_data = torch.load(f'./data/{DATASET}-split/wtsg.pth', weights_only=False)

In [None]:
# finetune_data = Subset(train_data, train_indices)
# test_data = Subset(train_data, test_indices)

In [None]:
# create dataloaders for training data
# train_loader keeps the dataset order (shuffle=False): pseudolabels produced
# from it are later paired with train_data purely by index.
train_loader = DataLoader(train_data, batch_size=128, shuffle=False)
# finetune_loader drives the training loop; shuffling gives each epoch a fresh
# batch order instead of replaying the same fixed sequence.
finetune_loader = DataLoader(train_data, batch_size=128, shuffle=True)
# test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

In [None]:
# save both datasets for later, in case run gets interrupted
# torch.save(finetune_data, './data/imagenet-split/finetune.pth')
# torch.save(test_data, './data/imagenet-split/test.pth')

In [None]:
# Define optimizer and loss function.
# Optimise exactly the parameters left trainable by the freezing cell
# (requires_grad=True), so the optimizer and the freeze policy cannot drift apart.
optimizer = optim.Adam((p for p in alexnet.parameters() if p.requires_grad), lr=1e-4)
criterion = nn.CrossEntropyLoss()  # default reduction: mean over the batch

In [None]:
# Choose the compute device (GPU when CUDA is available, else CPU) and move the model there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
alexnet = alexnet.to(device)

In [None]:
# Fine-tune the trainable parameters on finetune_loader.
# Relies on finetune_loader, optimizer, criterion and device from earlier cells.
alexnet.train()

NUM_EPOCHS = 20  # number of full passes over finetune_loader

for epoch in range(NUM_EPOCHS):
    # Reset running statistics every epoch so the bar reports this epoch's loss
    # rather than an average smeared across all previous epochs.
    cum_loss = 0.0
    total = 0

    progress_bar = tqdm(finetune_loader, leave=True)
    progress_bar.set_description(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = alexnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch *mean*; weight it by the batch size so
        # that cum_loss / total is the true per-sample average loss.
        cum_loss += loss.item() * len(labels)
        total += len(labels)

        avg_loss = cum_loss / total

        # Update the tqdm bar with loss and epoch
        progress_bar.set_postfix(loss=avg_loss)

In [None]:
# Save the fine-tuned model under a DATASET-specific path. The previous
# hardcoded models/cifar100/... path would overwrite a CIFAR-100 checkpoint
# with weights fine-tuned on this run's data.
ckpt_path = Path(f'models/{DATASET}/alexnet_{DATASET}_frozen3.pth')
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(alexnet.state_dict(), ckpt_path)

### Generate Pseudolabels

In [None]:
alexnet.eval()

# Predict a label for every training example and track agreement with the true
# labels. train_loader is built with shuffle=False, so pseudolabels[i] belongs to
# train_data[i]; the weak-labeled dataset built below relies on that alignment.
with torch.no_grad():
    correct = 0
    total = 0
    pseudolabels = []
    progress_bar = tqdm(train_loader, leave=True)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        outputs = alexnet(images)
        _, plabels = torch.max(outputs, 1)
        # Store plain Python ints instead of one 0-d tensor per example.
        pseudolabels.extend(plabels.cpu().tolist())

        correct += (plabels == labels).sum().item()
        total += len(labels)
        acc = 100 * correct / total

        progress_bar.set_postfix({"Acc": acc})

In [None]:
# Flatten pseudolabels to plain Python ints (the generation loop may leave one 0-d tensor per entry).
pseudolabels = [int(label) for label in pseudolabels]

In [None]:
class WeakLabeledData(Dataset):
    """Pair every example of a dataset with a weak (pseudo) label.

    Indexing returns ``(image, true_label, weak_label)`` so the ground-truth
    label stays available next to the weak one for later evaluation.

    Args:
        original_dataset: indexable dataset whose items unpack as
            ``(image, true_label)``.
        weak_labels: sized, indexable sequence where ``weak_labels[i]`` is the
            weak label for ``original_dataset[i]``.

    Raises:
        ValueError: if ``weak_labels`` and ``original_dataset`` differ in
            length, since the labels could then not be aligned by index.

    Note:
        Instances are written with ``torch.save`` in this notebook; loading
        them elsewhere requires this class to be importable under the same
        module path (standard pickle behaviour).
    """

    def __init__(self, original_dataset, weak_labels):
        if len(weak_labels) != len(original_dataset):
            raise ValueError(
                f"got {len(weak_labels)} weak labels for a dataset of "
                f"{len(original_dataset)} examples"
            )
        self.dataset = original_dataset
        self.weak_labels = weak_labels

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, true_label = self.dataset[index]
        return image, true_label, self.weak_labels[index]

In [None]:
# Attach the AlexNet pseudolabels to the training split, position by position.
weak_data = WeakLabeledData(original_dataset=train_data, weak_labels=pseudolabels)

In [None]:
# Persist the weak-labeled split alongside the other split files for DATASET
# (was hardcoded to the imagenet directory regardless of DATASET).
torch.save(weak_data, f'./data/{DATASET}-split/weak-labeled-half.pth')

### Evaluate AlexNet accuracy

In [None]:
# load the held-out test split (the evaluation cell below needs test_loader).
# Assumes ./data/{DATASET}-split/test.pth was written by the commented-out split
# cells next to finetune.pth -- TODO confirm the file exists.
# weights_only=False is needed to unpickle a Dataset object (torch>=2.6 default is True).
test_data = torch.load(f'./data/{DATASET}-split/test.pth', weights_only=False)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

In [None]:
# alexnet.load_state_dict(torch.load('alexnet_cifar100.pth'))

In [None]:
alexnet.eval()

# Initialize metrics
correct = 0
total = 0
test_loss = 0.0  # running sum of per-sample losses

with torch.no_grad():  # Disable gradient computation to save memory and computation
    progress_bar = tqdm(test_loader, leave=True)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)  # send data to gpu

        # Forward pass
        outputs = alexnet(images)

        # criterion averages over the batch; re-weight by batch size so that
        # test_loss / total is the exact per-sample mean even when the last
        # batch is smaller than the others.
        loss = criterion(outputs, labels)
        test_loss += loss.item() * len(labels)

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += len(labels)
        correct += (predicted == labels).sum().item()

        # Calculate metrics for display
        avg_loss = test_loss / total
        accuracy = 100 * correct / total

        # Update the tqdm bar with loss and accuracy
        progress_bar.set_postfix(loss=avg_loss, accuracy=accuracy)

# Final metrics use the same per-sample normalisation as the progress bar;
# max(..., 1) avoids ZeroDivisionError on an empty loader.
avg_loss = test_loss / max(total, 1)
accuracy = 100 * correct / max(total, 1)

print(f"Test Loss: {avg_loss:.4f}")
print(f"Test Accuracy: {accuracy:.2f}%")

In [None]:
from utils import evaluate_pseudolabels

In [None]:
# Score the weak-labeled data with the project helper from utils, iterating in
# stored order. Presumably it compares weak labels with the true labels -- TODO confirm.
evaluate_pseudolabels(
    DataLoader(weak_data, shuffle=False, batch_size=128)
)