In [1]:
import torch
import torchvision.models as models
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
import cv2
import numpy as np

from scipy.spatial.distance import pdist
from sklearn.cluster import KMeans

from pprint import pprint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FEAT_DIM=512

In [2]:
# Define the transformations to apply to the CIFAR-10 data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the image tensors
])

# Define the training and test datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Define the dataloaders to load the data in batches during training and testing
batch_size = 64

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [4]:
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of testing batches: {len(test_loader)}")

Number of training batches: 782
Number of testing batches: 157


In [5]:
# extract vgg features
def get_features(extractor, dataloader):
    
    extractor.eval()
    extractor = extractor.to(device)
    result = None
    for i, (x, y) in enumerate(dataloader):
        output = extractor(x.to(device)).squeeze().detach().cpu().numpy()
        if type(result)==np.ndarray:
            # not empty
            result = np.concatenate([result, output], axis=0)
        else:
            result = output.copy()

    return result

In [6]:
vgg16 = models.vgg16(pretrained=True)
extractor = vgg16.features

feats = get_features(extractor, train_loader)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [7]:
# work with the first 1000 features
# features = feats[:1000]
features = feats

In [8]:
# get pairwise distance
def get_pairwise_distance(features):
    
    dist = pdist(features, metric="euclidean")
    return dist

# dists = get_pairwise_distance(feats)
# print(dists.shape)

In [9]:
# cluster the features
def cluster_features(features, num_clusters=5):
    
    cobj = KMeans(n_clusters=num_clusters)
    
    # do the clustering
    cobj.fit(features)
    
    # get cluster assignments for the feature vectors
    assignments = cobj.labels_
    
    return assignments    

In [10]:
# TODO: experiments to see what number of cluster assignments works well 
NUM_CLUSTERS=5
c_labels = cluster_features(features, num_clusters=NUM_CLUSTERS)

In [11]:
# separate out the data into clusters
from collections import defaultdict

clustered_data = defaultdict(list)

for idx, l in enumerate(c_labels):
    clustered_data[l].append(idx)

In [12]:
# cluster indices and number of samples per cluster
idx_size = []

for l in clustered_data.keys():
    idx_size.append((l, len(clustered_data[l])))

# sort by the number of samples in the cluster
idx_size = sorted(idx_size, key=lambda x: x[1], reverse=True)
print(idx_size)

[(1, 17686), (0, 11627), (3, 11611), (4, 4729), (2, 4347)]


In [19]:
def model_eval(model, dataloader):
    
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

In [25]:
# the network to be trained
model = models.resnet18(pretrained=False)
num_classes = 10

learning_rate = 1e-3

# Modify the last fully connected layer
fc_input = model.fc.in_features
model.fc = nn.Linear(fc_input, num_classes)

# print(model)

# Step 5: Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [26]:
model = model.to(device)

NUM_EPOCHS=10
MODE="L2S"

# train sequentially on different clusters (larger to smaller)
for c_idx, _ in idx_size:
    
    log_dir = f"./logs/vgg16_{MODE}_EQ{NUM_EPOCHS}_c{NUM_CLUSTERS}_{c_idx}"  # Set the directory for storing the logs
    writer = SummaryWriter(log_dir)
    
    # record the indices
    idx_size_str = ', '.join(str(val) for val in idx_size)
    writer.add_text("cluster sizes", idx_size_str)
    
    # create dataloader for the cluster
    cluster_loader = torch.utils.data.DataLoader(train_dataset, 
                                                 batch_size=8,
                                                 num_workers=2, 
                                                 sampler=SubsetRandomSampler(clustered_data[c_idx]),
                                                 drop_last=True)
    
    # train using this data-loader
    for epoch in range(NUM_EPOCHS):         # TBD: How many epochs per cluster?
        running_loss = 0.0
        for i, data in enumerate(cluster_loader):
            inputs, labels = data[0].to(device), data[1].to(device)
            
            # zero out the gradients
            optimizer.zero_grad()

            # get predictions and compute loss
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # backpropagate
            loss.backward()
            optimizer.step()
            
            # track the loss
            running_loss += loss.item()
        
        # Log the train loss
        epoch_loss = running_loss/len(cluster_loader)
        writer.add_scalar('Loss/train', epoch_loss, epoch+1)
        running_loss = 0.0
        
        # Log the test accuracy
        test_accuracy = model_eval(model, test_loader)
        writer.add_scalar('Accuracy/test', test_accuracy, epoch+1)
    
    writer.close()
    print(f"Cluster {c_idx} ({len(clustered_data[c_idx])/len(train_dataset)*100:.2f}% data) done. Test Acc: {test_accuracy:.3f}")

Cluster 1 (35.37% data) done. Test Acc: 56.960
Cluster 0 (23.25% data) done. Test Acc: 53.510
Cluster 3 (23.22% data) done. Test Acc: 54.800
Cluster 4 (9.46% data) done. Test Acc: 49.780
Cluster 2 (8.69% data) done. Test Acc: 38.860


In [22]:
writer.close()