In [1]:
from torchvision import models
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torchvision.models import resnet18
import torch.nn.functional as F
from pytorch_lightning import seed_everything
import random
from sklearn.metrics import classification_report

num_anchors = 300
hidden_features = 512
seed_everything(1)
one_dimentional_data = False

# Step 1: Define the transformations for the training and test sets.
# The pipeline is assembled from a plain list rather than using an identity
# `transforms.Lambda(lambda x: x)` placeholder: a lambda stored inside the
# dataset transform cannot be pickled, which breaks DataLoader workers
# (num_workers > 0) on platforms that start workers with "spawn".
transform_steps = []
if one_dimentional_data:
    # Single-channel inputs are expanded to 3 channels before the backbone.
    transform_steps.append(transforms.Grayscale(3))
transform_steps += [
    transforms.Resize(224),
    transforms.ToTensor(),
    # Per-channel mean / std used to standardise the three colour channels.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
transform = transforms.Compose(transform_steps)

# Step 2: Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Step 3: select and partition anchors
def get_anchors_idx(num_anchors: int, max_anchors_len: int):
    """Pick `num_anchors` distinct indices at random from `range(max_anchors_len)`.

    The whole index range is shuffled with the global `random` generator and the
    first `num_anchors` entries are kept, so the draw is reproducible whenever
    that generator has been seeded.

    Args:
        num_anchors: how many indices to return.
        max_anchors_len: size of the pool the indices are drawn from.

    Returns:
        A list of `num_anchors` unique ints, in shuffled order.

    Raises:
        AssertionError: if more anchors are requested than the pool contains.
    """
    assert num_anchors <= max_anchors_len
    candidates = [*range(max_anchors_len)]
    random.shuffle(candidates)
    return candidates[:num_anchors]

# Draw the anchor indices from the whole training set, then view them as a
# Subset of `trainset` so anchors go through the same transform as training data.
anchors_idx = get_anchors_idx(num_anchors=num_anchors, max_anchors_len=len(trainset))
anchors_set = torch.utils.data.Subset(dataset=trainset, indices=anchors_idx)
# batch_size equals the number of anchors, so one batch holds every anchor.
anchorloader = torch.utils.data.DataLoader(
    dataset=anchors_set,
    batch_size=num_anchors,
    shuffle=False,
    num_workers=2,
)

class RelRepBlock(nn.Module):
    """Turn latent vectors into relative representations w.r.t. a set of anchors.

    Each output entry is the cosine similarity between one input latent and one
    anchor latent; optionally the similarity vector is passed through LayerNorm.

    Args:
        normalization: when truthy, a LayerNorm over the anchor dimension is
            created and applied to the similarities.
        num_anchors: number of anchors, i.e. the size of the last output dimension.
    """

    def __init__(self, normalization, num_anchors):
        super().__init__()
        self.num_anchors = num_anchors
        self.normalization = normalization
        # The norm layer only exists when normalization was requested.
        if normalization:
            self.outnorm = nn.LayerNorm(normalized_shape=num_anchors)

    def forward(self, x: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
        """Return the (n, a) matrix of cosine similarities between `x` and `anchors`.

        Args:
            x: 2-D tensor of n latents with feature size d.
            anchors: 2-D tensor of a anchor latents with the same feature size d.
        """
        # L2-normalising both sides makes the dot product a cosine similarity.
        unit_x = F.normalize(x, p=2, dim=-1)
        unit_anchors = F.normalize(anchors, p=2, dim=-1)
        similarities = torch.einsum("bd,kd->bk", unit_x, unit_anchors)

        if not self.normalization:
            return similarities
        return self.outnorm(similarities)

# Step 4: Initialize ResNet-18 model and modify it
class RelativeResNet18(nn.Module):
    """ResNet-18 classifier whose head operates on relative representations.

    Pipeline: pretrained ResNet-18 (final `fc` replaced by identity) -> small MLP
    -> cosine similarities to the anchor latents (`RelRepBlock`) -> 10-way linear
    layer over those similarities.

    Args:
        num_anchors: number of anchors; sets the width of the relative
            representation and the input size of the output layer.
        hidden_features: output width of the post-backbone MLP.
        fine_tune: when False (default) the backbone parameters are frozen and
            the backbone runs without gradient tracking. When True the backbone
            is run with gradient tracking for the input batch so it can actually
            be trained.
    """

    def __init__(self, num_anchors, hidden_features, fine_tune=False):
        super(RelativeResNet18, self).__init__()
        # Remembered so forward() can decide whether the backbone needs gradients.
        self.fine_tune = bool(fine_tune)

        # Remove the last fully connected layer
        self.resnet = resnet18(pretrained=True)
        self.resnet18_fc_shape = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()

        if not fine_tune:
            for param in self.resnet.parameters():
                param.requires_grad = False
                param.grad = None
            # NOTE(review): this only disables gradients; the backbone's BatchNorm
            # layers still update running statistics while the module is in train
            # mode — confirm that is intended.

        self.post_resnet_nn = nn.Sequential(
            nn.Linear(in_features=self.resnet18_fc_shape, out_features=self.resnet18_fc_shape),
            nn.BatchNorm1d(num_features=self.resnet18_fc_shape),
            nn.Tanh(),
            nn.Linear(in_features=self.resnet18_fc_shape, out_features=hidden_features),
            nn.Tanh(),
        )
        # Relative Representation Transform
        self.relative_transform = RelRepBlock(normalization=True, num_anchors=num_anchors)
        self.output_layer = nn.Linear(num_anchors, 10)

    def forward(self, x, anchors):
        """Classify `x` from its similarities to `anchors`.

        Args:
            x: batch of input images fed to the backbone.
            anchors: batch of anchor images fed through the same backbone and MLP.

        Returns:
            Logits with one row per input and 10 columns.
        """
        # Anchors only serve as a reference frame: never track gradients for them.
        with torch.no_grad():
            anchors_latents = self.resnet(anchors)
            anchors_latents = self.post_resnet_nn(anchors_latents)

        # Previously the backbone always ran under no_grad, so fine_tune=True had
        # no effect. Track gradients only when fine-tuning, and never re-enable
        # them inside an outer torch.no_grad() (e.g. during evaluation).
        track_backbone_grad = self.fine_tune and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_backbone_grad):
            x_latents = self.resnet(x)

        x_latents = self.post_resnet_nn(x_latents)
        relative_reps = self.relative_transform(x_latents, anchors_latents)
        return self.output_layer(relative_reps)


# Initialize the modified model (relative-representation variant)
model = RelativeResNet18(num_anchors=num_anchors, hidden_features=hidden_features)

# Use the GPU when one is available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

# Step 5: Define the loss function and optimizer (Adam with its default settings)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters())

# Step 6: Train the model
def train(epoch):
    """Run one optimisation pass over `trainloader`.

    The anchor batch is fetched once per epoch and reused for every step. The
    mean loss of each window of 100 mini-batches is printed.

    Args:
        epoch: zero-based epoch index, used only in the progress message.
    """
    model.train()

    # First (and only) batch of the anchor loader; its labels are not needed.
    anchor_images, _ = next(iter(anchorloader))
    anchor_images = anchor_images.to(device)

    window_loss = 0.0
    for step, (inputs, labels) in enumerate(trainloader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs, anchor_images), labels)
        loss.backward()
        optimizer.step()

        window_loss += loss.item()
        if step % 100 == 0:    # print every 100 mini-batches
            print(f'[Epoch {epoch + 1}, Batch {step}] loss: {window_loss / 100:.3f}')
            window_loss = 0.0

# Step 7: Evaluate the model
def test():
    """Evaluate the model on `testloader` and print a classification report.

    Predictions are the arg-max class of the model output for each test image,
    computed with gradient tracking disabled and the model in eval mode.
    """
    model.eval()

    # Same anchor batch construction as in training; labels are discarded.
    anchor_images, _ = next(iter(anchorloader))
    anchor_images = anchor_images.to(device)

    all_labels, all_predictions = [], []
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images, anchor_images)
            # Index of the largest logit along the class dimension.
            predicted = logits.max(dim=1).indices
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    print(classification_report(all_labels, all_predictions))

# Training and Testing Loop: evaluate on the test set after every epoch
num_epochs = 10  # number of epochs
for epoch in range(num_epochs):
    train(epoch)
    test()

print('Finished Training')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 80393027.56it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 141MB/s]


[Epoch 1, Batch 100] loss: 1.122
[Epoch 1, Batch 200] loss: 0.804
[Epoch 1, Batch 300] loss: 0.775
[Epoch 1, Batch 400] loss: 0.757
[Epoch 1, Batch 500] loss: 0.760
[Epoch 1, Batch 600] loss: 0.702
[Epoch 1, Batch 700] loss: 0.697
[Epoch 1, Batch 800] loss: 0.696
[Epoch 1, Batch 900] loss: 0.689
[Epoch 1, Batch 1000] loss: 0.702
[Epoch 1, Batch 1100] loss: 0.712
[Epoch 1, Batch 1200] loss: 0.660
[Epoch 1, Batch 1300] loss: 0.684
[Epoch 1, Batch 1400] loss: 0.672
[Epoch 1, Batch 1500] loss: 0.675
              precision    recall  f1-score   support

           0       0.79      0.81      0.80      1000
           1       0.88      0.87      0.87      1000
           2       0.71      0.74      0.72      1000
           3       0.75      0.51      0.61      1000
           4       0.76      0.75      0.75      1000
           5       0.68      0.81      0.74      1000
           6       0.76      0.88      0.81      1000
           7       0.86      0.77      0.82      1000
           8

In [2]:
from torchvision import models
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torchvision.models import resnet18
import torch.nn.functional as F
from pytorch_lightning import seed_everything
import random
from sklearn.metrics import classification_report


hidden_features = 512
seed_everything(1)
one_dimentional_data = False

# Step 1: Define the transformations for the training and test sets.
# The pipeline is assembled from a plain list rather than using an identity
# `transforms.Lambda(lambda x: x)` placeholder: a lambda stored inside the
# dataset transform cannot be pickled, which breaks DataLoader workers
# (num_workers > 0) on platforms that start workers with "spawn".
transform_steps = []
if one_dimentional_data:
    # Single-channel inputs are expanded to 3 channels before the backbone.
    transform_steps.append(transforms.Grayscale(3))
transform_steps += [
    transforms.Resize(224),
    transforms.ToTensor(),
    # Per-channel mean / std used to standardise the three colour channels.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
transform = transforms.Compose(transform_steps)


# Step 2: Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


# Step 4: Initialize ResNet-18 model and modify it
class ResNet18(nn.Module):
    """Baseline classifier: pretrained ResNet-18 features -> small MLP -> 10-way linear layer.

    Args:
        num_anchors: accepted but not used by this model.
        hidden_features: output width of the post-backbone MLP and input width
            of the output layer.
        fine_tune: when False (default) the backbone parameters are frozen and
            the backbone runs without gradient tracking. When True the backbone
            is run with gradient tracking so it can actually be trained.
    """

    def __init__(self, num_anchors, hidden_features, fine_tune=False):
        super(ResNet18, self).__init__()
        # Remembered so forward() can decide whether the backbone needs gradients.
        self.fine_tune = bool(fine_tune)

        # Remove the last fully connected layer
        self.resnet = resnet18(pretrained=True)
        self.resnet18_fc_shape = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()

        if not fine_tune:
            for param in self.resnet.parameters():
                param.requires_grad = False
                param.grad = None
            # NOTE(review): this only disables gradients; the backbone's BatchNorm
            # layers still update running statistics while the module is in train
            # mode — confirm that is intended.

        # Layer creation order is kept as-is so seeded initialisation is unchanged.
        self.output_layer = nn.Linear(in_features=hidden_features, out_features=10)
        self.post_resnet_nn = nn.Sequential(
            nn.Linear(in_features=self.resnet18_fc_shape, out_features=self.resnet18_fc_shape),
            nn.BatchNorm1d(num_features=self.resnet18_fc_shape),
            nn.Tanh(),
            nn.Linear(in_features=self.resnet18_fc_shape, out_features=hidden_features),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return logits (one row per input, 10 columns) for the image batch `x`."""
        # Previously the backbone always ran under no_grad, so fine_tune=True had
        # no effect. Track gradients only when fine-tuning, and never re-enable
        # them inside an outer torch.no_grad() (e.g. during evaluation).
        track_backbone_grad = self.fine_tune and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_backbone_grad):
            x_latents = self.resnet(x)
        x_latents = self.post_resnet_nn(x_latents)
        return self.output_layer(x_latents)


# Initialize the modified model
# The baseline ResNet18 never reads `num_anchors`, and this cell does not define
# it; pass an explicit placeholder instead of depending on a global left over
# from another cell.
model = ResNet18(hidden_features=hidden_features, num_anchors=None)

# If using a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Step 5: Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Step 6: Train the model
def train(epoch):
    """Run one optimisation pass over `trainloader`.

    The mean loss of each window of 100 mini-batches is printed.

    Args:
        epoch: zero-based epoch index, used only in the progress message.
    """
    model.train()

    window_loss = 0.0
    for step, (inputs, labels) in enumerate(trainloader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        window_loss += loss.item()
        if step % 100 == 0:    # print every 100 mini-batches
            print(f'[Epoch {epoch + 1}, Batch {step}] loss: {window_loss / 100:.3f}')
            window_loss = 0.0

# Step 7: Evaluate the model
def test():
    """Evaluate the model on `testloader` and print a classification report.

    Predictions are the arg-max class of the model output for each test image,
    computed with gradient tracking disabled and the model in eval mode.
    """
    model.eval()

    all_labels, all_predictions = [], []
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            # Index of the largest logit along the class dimension.
            predicted = logits.max(dim=1).indices
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    print(classification_report(all_labels, all_predictions))

# Training and Testing Loop: evaluate on the test set after every epoch
num_epochs = 10  # number of epochs
for epoch in range(num_epochs):
    train(epoch)
    test()

print('Finished Training')

Files already downloaded and verified
Files already downloaded and verified




[Epoch 1, Batch 100] loss: 0.945
[Epoch 1, Batch 200] loss: 0.776
[Epoch 1, Batch 300] loss: 0.769
[Epoch 1, Batch 400] loss: 0.742
[Epoch 1, Batch 500] loss: 0.737
[Epoch 1, Batch 600] loss: 0.694
[Epoch 1, Batch 700] loss: 0.714
[Epoch 1, Batch 800] loss: 0.701
[Epoch 1, Batch 900] loss: 0.693
[Epoch 1, Batch 1000] loss: 0.693
[Epoch 1, Batch 1100] loss: 0.732
[Epoch 1, Batch 1200] loss: 0.709
[Epoch 1, Batch 1300] loss: 0.649
[Epoch 1, Batch 1400] loss: 0.657
[Epoch 1, Batch 1500] loss: 0.677
              precision    recall  f1-score   support

           0       0.79      0.81      0.80      1000
           1       0.91      0.84      0.87      1000
           2       0.79      0.58      0.67      1000
           3       0.61      0.71      0.66      1000
           4       0.82      0.67      0.74      1000
           5       0.72      0.75      0.74      1000
           6       0.73      0.91      0.81      1000
           7       0.79      0.80      0.80      1000
           8