ConvNeXt Task 2

In [None]:
import pickle
import warnings
warnings.filterwarnings("ignore")
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import timm  # For ConvNeXt models
import numpy as np
from sklearn.metrics import accuracy_score

batch_size = 16  # images per forward pass in get_features

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ConvNeXt-Small backbone (its final stage gives 768 features).
model = timm.create_model('convnext_small', pretrained=True)
# Keep every top-level child except the last one (the classification head),
# so the extractor returns feature maps instead of class logits.
feature_extractor = torch.nn.Sequential(*[*model.children()][:-1])
feature_extractor.eval()

# Per-image preprocessing: convert to PIL, resize to ConvNeXt's default
# 224x224 input, convert back to a tensor, then normalize with the usual
# ImageNet channel mean/std.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

#This is a class defined to create the dataset
class UnlabeledDataset(torch.utils.data.Dataset):
    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        img = self.transform(img)
        return img

# Learning-with-Prototypes (LWP) classifier
class LWPClassifier:
    """Nearest-prototype classifier with one mean feature vector per class.

    A sample is assigned to the class whose prototype is closest in
    Euclidean distance.
    """

    def __init__(self, num_classes=2):
        self.num_classes = num_classes
        # (num_classes, num_features) array; set by ``fit``.
        self.prototypes = None

    def fit(self, X, y):
        """Set each class prototype to the mean feature vector of that class.

        Args:
            X: 2-D array-like of features, one row per sample (a numpy array
                or a CPU torch tensor).
            y: 1-D array-like of integer labels in ``range(num_classes)``.

        A class with no samples in ``y`` keeps its current prototype when one
        of matching shape exists. Otherwise its prototype is filled with
        ``+inf`` so that ``predict`` never selects it. Without this, the mean
        of an empty selection is NaN, and ``np.argmin`` returns the index of
        a NaN distance, so a single empty class would capture every
        prediction.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        num_features = X.shape[1]
        # Old prototypes are only reusable if they line up with this data.
        can_reuse = np.shape(self.prototypes) == (self.num_classes, num_features)
        prototypes = []
        for c in range(self.num_classes):
            members = X[y == c]
            if len(members) > 0:
                prototypes.append(members.mean(axis=0))
            elif can_reuse:
                prototypes.append(np.asarray(self.prototypes[c]))
            else:
                prototypes.append(np.full(num_features, np.inf))
        self.prototypes = np.array(prototypes)

    def predict(self, X):
        """Return, for each row of ``X``, the index of its nearest prototype.

        Args:
            X: 2-D array-like of features (numpy array or CPU torch tensor).

        Returns:
            1-D integer numpy array of predicted class indices.
        """
        X = np.asarray(X)
        # Broadcast to (num_samples, num_classes, num_features), then take the
        # norm over features to get shape (num_samples, num_classes).
        distances = np.linalg.norm(X[:, np.newaxis] - self.prototypes, axis=2)
        return np.argmin(distances, axis=1)

# Embed images with the pretrained backbone, one mini-batch at a time.
def get_features(images, transform):
    """Return globally average-pooled backbone features as one CPU tensor.

    Each image goes through ``transform``. Images are batched with the
    module-level ``batch_size`` in input order and run through
    ``feature_extractor`` on ``device`` with gradients disabled. The outputs
    are averaged over dimensions 2 and 3. Row ``i`` of the result belongs to
    ``images[i]``.

    NOTE(review): assumes ``feature_extractor`` has already been moved to
    ``device`` by the caller.
    """
    loader = DataLoader(
        UnlabeledDataset(images, transform),
        batch_size=batch_size,
        shuffle=False,  # keep rows aligned with the input order
    )
    feature_extractor.eval()
    with torch.no_grad():
        pooled_batches = [
            # Global average pooling over dims 2 and 3, then back to the CPU.
            feature_extractor(batch.to(device)).mean([2, 3]).cpu()
            for batch in loader
        ]
    return torch.cat(pooled_batches)

# Load the f10 classifier from the file (presumably the LWPClassifier produced
# by the previous task); it is refit on each new domain below.
# NOTE(review): pickle.load can execute arbitrary code - only load a file you
# created yourself.
with open("lwp_classifier.pkl", "rb") as file:
    f = pickle.load(file)

data_paths = [f"dataset/part_two_dataset/train_data/{i}_train_data.tar.pth" for i in range(1, 11)]
heldout_paths = [f"dataset/part_two_dataset/eval_data/{i}_eval_data.tar.pth" for i in range(1, 11)]
# Derive the domain count from the path lists instead of repeating the literal 10.
num_domains = len(data_paths)

all_eval_features = []  # pooled eval features of every domain seen so far
all_eval_labels = []    # matching ground-truth eval labels
# acc_matrix[i, j]: accuracy on held-out set j after adapting to domain i.
# Only entries with j <= i are filled; the rest stay 0.
acc_matrix = np.zeros((num_domains, num_domains))

# The extractor is only used for inference here, so move it to the device
# once instead of on every iteration.
feature_extractor = feature_extractor.to(device)

# Main training and evaluation loop. Domain `loop` is reported as D{loop+11}
# and the classifier refit on it as f{loop+11}.
for loop in range(num_domains):
    # NOTE(review): PyTorch >= 2.6 defaults to torch.load(weights_only=True),
    # which may reject non-tensor payloads. If these trusted files fail to
    # load, pass weights_only=False.
    data = torch.load(data_paths[loop])
    heldout = torch.load(heldout_paths[loop])

    train_images = data['data']
    eval_images = heldout['data']
    eval_labels = heldout['targets']

    # Training labels are not used. The training set is pseudo-labelled with
    # the current classifier, and the prototypes are then refit on those
    # pseudo-labels.
    train_features = get_features(train_images, transform)
    print(f"D{loop+11} training features extracted")
    train_labels = f.predict(train_features)
    f.fit(train_features, train_labels)

    eval_features = get_features(eval_images, transform)
    print(f"D{loop+11} eval features extracted")
    all_eval_features.append(eval_features)
    all_eval_labels.append(eval_labels)

    # Evaluate the refit classifier on every held-out set seen so far; the
    # last entry is the current domain.
    accuracies = []
    for j in range(loop + 1):
        acc = accuracy_score(all_eval_labels[j], f.predict(all_eval_features[j]))
        acc_matrix[loop, j] = acc
        accuracies.append(acc)
    print(f"f{loop+11} Test Accuracies: {accuracies}\n")

print("Accuracy matrix\n")
print(acc_matrix)

model.safetensors:   0%|          | 0.00/201M [00:00<?, ?B/s]

D11 training features extracted
D11 eval features extracted
f11 Test Accuracies: [0.7684]

D12 training features extracted
D12 eval features extracted
f12 Test Accuracies: [0.7336, 0.6276]

D13 training features extracted
D13 eval features extracted
f13 Test Accuracies: [0.7476, 0.6064, 0.8192]

D14 training features extracted
D14 eval features extracted
f14 Test Accuracies: [0.7616, 0.6008, 0.8244, 0.9032]

D15 training features extracted
D15 eval features extracted
f15 Test Accuracies: [0.7508, 0.5928, 0.8232, 0.9036, 0.9056]

D16 training features extracted
D16 eval features extracted
f16 Test Accuracies: [0.7488, 0.5912, 0.8244, 0.8864, 0.8956, 0.7688]

D17 training features extracted
D17 eval features extracted
f17 Test Accuracies: [0.7416, 0.588, 0.8124, 0.8836, 0.8872, 0.7708, 0.8264]

D18 training features extracted
D18 eval features extracted
f18 Test Accuracies: [0.7464, 0.594, 0.8232, 0.88, 0.886, 0.7592, 0.8268, 0.7948]

D19 training features extracted
D19 eval features ext