In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100
from tqdm import tqdm

from utils import WeakLabeledData

In [2]:
GROUND_TRUTH = True
DATASET='imagenet'
STARTING_EPOCH = 0

In [3]:
def setup_resnet50(num_classes=10, freeze=True):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    
    if num_classes is not None:
        print(f"Modifying last layer for {num_classes} classes")
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    
    if freeze:
        # freeze layers for finetuning
        for param in model.parameters():
            param.requires_grad = False

        # unfreeze last layer to finetune
        for param in model.fc.parameters():
            param.requires_grad = True
        
        for param in model.layer4.parameters():
            param.requires_grad = True
            
    return model

In [4]:
# Load the model
model = setup_resnet50(None)

if STARTING_EPOCH > 0:
    model.load_state_dict(torch.load(f"models/{DATASET}/resnet50_{'gt' if GROUND_TRUTH else 'wtsg'}_epoch{STARTING_EPOCH}.pth"))

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 329MB/s]


In [5]:
train_set = torch.load(f'./data/{DATASET}-split/weak-labeled-half.pth')
train_loader = DataLoader(train_set, batch_size=128, shuffle=False)

In [6]:
# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [7]:
# Training Loop setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [8]:
# Training Loop
model.train()

NUM_EPOCHS = 10  # Fine-tuning for 3 epochs

for epoch in range(STARTING_EPOCH, NUM_EPOCHS):
    cum_loss = 0.0
    total = 0
    
    progress_bar = tqdm(train_loader, leave=True)
    progress_bar.set_description(f"Epoch [{epoch+1}/{NUM_EPOCHS}] ({'GT' if GROUND_TRUTH else 'WTSG'})")
    for idx, (images, true_labels, weak_labels) in enumerate(progress_bar):
        images, true_labels, weak_labels = images.to(device), true_labels.to(device), weak_labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        
        if GROUND_TRUTH:
            # training step for ground truth model
            loss = criterion(outputs, true_labels)
        else:
            # training step for weak labels model
            loss = criterion(outputs, weak_labels)
    
        loss.backward()
        optimizer.step()
        
        cum_loss += loss.item()
        total += len(true_labels)
        
        avg_loss = cum_loss / total
        
        # Update the tqdm bar with loss and epoch
        progress_bar.set_postfix({"avg_loss":avg_loss, "cum_loss":cum_loss})

Epoch [1/10] (GT): 100%|██████████| 313/313 [08:30<00:00,  1.63s/it, avg_loss=0.00894, cum_loss=358]
Epoch [2/10] (GT): 100%|██████████| 313/313 [08:04<00:00,  1.55s/it, avg_loss=0.00318, cum_loss=127] 
Epoch [3/10] (GT): 100%|██████████| 313/313 [07:59<00:00,  1.53s/it, avg_loss=0.00108, cum_loss=43.3]
Epoch [4/10] (GT): 100%|██████████| 313/313 [07:18<00:00,  1.40s/it, avg_loss=0.000419, cum_loss=16.8]
Epoch [5/10] (GT): 100%|██████████| 313/313 [08:39<00:00,  1.66s/it, avg_loss=0.00021, cum_loss=8.41] 
Epoch [6/10] (GT): 100%|██████████| 313/313 [08:32<00:00,  1.64s/it, avg_loss=0.00013, cum_loss=5.22] 
Epoch [7/10] (GT): 100%|██████████| 313/313 [09:05<00:00,  1.74s/it, avg_loss=9.11e-5, cum_loss=3.65] 
Epoch [8/10] (GT): 100%|██████████| 313/313 [08:38<00:00,  1.66s/it, avg_loss=6.93e-5, cum_loss=2.77] 
Epoch [9/10] (GT): 100%|██████████| 313/313 [08:46<00:00,  1.68s/it, avg_loss=5.48e-5, cum_loss=2.19] 
Epoch [10/10] (GT): 100%|██████████| 313/313 [08:23<00:00,  1.61s/it, avg_los

In [9]:
# Save the fine-tuned model
name = f"resnet50_{'gt' if GROUND_TRUTH else 'wtsg'}_epoch10.pth"
torch.save(model.state_dict(), f'models/{DATASET}/{name}')

In [10]:
# model.load_state_dict(torch.load(f"models/{DATASET}/fine_tuned_resnet50_{'gt' if GROUND_TRUTH else 'wtsg'}_epoch10.pth"))

In [11]:
# load the test data
# Define preprocessing transformations
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # All 3 models expect 224x224 images
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # normalization constants for ImageNet-1k (pre-training data)
# ])

# test_data = CIFAR100(root='./data', train=False, transform=transform)
test_data = torch.load(f'./data/{DATASET}-split/test.pth')
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

In [12]:
model.eval()

# Initialize metrics
correct = 0
total = 0
test_loss = 0.0

with torch.no_grad():  # Disable gradient computation to save memory and computation
    progress_bar = tqdm(test_loader, leave=True)
    for idx, (images, labels) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device) # send data to gpu
        
        # Forward pass
        outputs = model(images)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        test_loss += loss.item()  # Accumulate loss
        
        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += len(labels)
        correct += (predicted == labels).sum().item()
        
        # Calculate metrics for display
        avg_loss = test_loss / total
        accuracy = 100 * correct / total

        # Update the tqdm bar with loss and accuracy
        progress_bar.set_postfix(loss=avg_loss, accuracy=accuracy)

# Calculate final metrics
avg_loss = test_loss / len(test_loader)
accuracy = 100 * correct / total

print(f"Test Loss: {avg_loss:.4f}")
print(f"Test Accuracy: {accuracy:.2f}%")

100%|██████████| 79/79 [01:55<00:00,  1.46s/it, accuracy=74.1, loss=0.00887]

Test Loss: 1.1227
Test Accuracy: 74.09%



