In [1]:
! pip install idx2numpy



In [2]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import idx2numpy
import requests
import gzip
import numpy as np
from io import BytesIO
from torchvision import transforms

# Custom dataset class for FashionMNIST with normalization
class FashionImagesDataset(Dataset):
    def __init__(self, img_url, lbl_url, transform=None):
        self.img_data = self.fetch_and_extract(img_url)
        self.lbl_data = self.fetch_and_extract(lbl_url)
        self.transform = transform

    def fetch_and_extract(self, url):
        response = requests.get(url)
        response.raise_for_status()
        with gzip.GzipFile(fileobj=BytesIO(response.content)) as f:
            return idx2numpy.convert_from_file(f)

    def __len__(self):
        return len(self.img_data)

    def __getitem__(self, idx):
        img = self.img_data[idx].reshape(1, 28, 28) / 255.0  # Normalize to [0,1]
        img = torch.tensor(img, dtype=torch.float32)
        label = self.lbl_data[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

# URLs for dataset
train_imgs_url = "https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/train-images-idx3-ubyte.gz"
train_lbls_url = "https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/train-labels-idx1-ubyte.gz"
test_imgs_url = "https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/t10k-images-idx3-ubyte.gz"
test_lbls_url = "https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/t10k-labels-idx1-ubyte.gz"

# Define transformations
transform = transforms.Normalize((0.5,), (0.5,))  # Normalizing images to mean 0.5 and std 0.5

# Load the data with transformations
train_data = FashionImagesDataset(train_imgs_url, train_lbls_url, transform=transform)
test_data = FashionImagesDataset(test_imgs_url, test_lbls_url, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

# Define a simpler CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Define training and evaluation functions
def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluate the model after each epoch
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        accuracy = 100 * correct / total
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}, Accuracy: {accuracy:.2f}%')

    # Save the model
    torch.save(model.state_dict(), "/content/simple_cnn_model.pt")
    print("Model saved to /content/simple_cnn_model.pt")

# Instantiate and train the simpler model
model = SimpleCNN()
train_model(model, train_loader, test_loader, epochs=10, lr=0.001)


Epoch [1/10], Loss: 0.5135, Accuracy: 86.73%
Epoch [2/10], Loss: 0.3274, Accuracy: 89.13%
Epoch [3/10], Loss: 0.2827, Accuracy: 90.20%
Epoch [4/10], Loss: 0.2524, Accuracy: 90.46%
Epoch [5/10], Loss: 0.2283, Accuracy: 90.99%
Epoch [6/10], Loss: 0.2084, Accuracy: 90.99%
Epoch [7/10], Loss: 0.1897, Accuracy: 91.89%
Epoch [8/10], Loss: 0.1736, Accuracy: 91.91%
Epoch [9/10], Loss: 0.1602, Accuracy: 91.88%
Epoch [10/10], Loss: 0.1461, Accuracy: 92.00%
Model saved to /content/simple_cnn_model.pt


In [5]:
import requests
import torch

# URL for the model checkpoint on GitHub
model_url = "https://github.com/clionmuhoza/assignment_3/raw/main/simple_cnn_model.pt"

# Load the model from a checkpoint on GitHub
def load_checkpoint_from_url(model_class, url, lr=0.001):
    # Download the model checkpoint from GitHub
    response = requests.get(url)
    response.raise_for_status()
    with open("downloaded_checkpoint.pt", "wb") as f:
        f.write(response.content)

    # Load the checkpoint and inspect its structure
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load("downloaded_checkpoint.pt", map_location=device)

    # Print the keys to see what is available in the checkpoint
    print("Checkpoint keys:", checkpoint.keys())

    # Instantiate the model and optimizer
    model = model_class().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Check if 'model_state_dict' key is in the checkpoint
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)
        print(f"Model loaded from {url}, resuming from epoch {start_epoch}")
    else:
        # If checkpoint only has model weights (no optimizer or epoch info)
        model.load_state_dict(checkpoint)
        start_epoch = 0
        print(f"Model weights loaded from {url} without optimizer or epoch info. Starting from epoch {start_epoch}")

    return model, optimizer, start_epoch

# Example
# Load the model, optimizer, and starting epoch from GitHub
model, optimizer, start_epoch = load_checkpoint_from_url(SimpleCNN, model_url, lr=0.001)


Checkpoint keys: odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
Model weights loaded from https://github.com/clionmuhoza/assignment_3/raw/main/simple_cnn_model.pt without optimizer or epoch info. Starting from epoch 0


  checkpoint = torch.load("downloaded_checkpoint.pt", map_location=device)
