In [39]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from tqdm import tqdm
import torchvision.models as models

In [40]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [41]:
class SimCLRTransform:
    def __init__(self, size=32):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.ToTensor()
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)

In [42]:
train_dataset = torchvision.datasets.CIFAR10(root='./', train=True, download=True,
                                              transform=SimCLRTransform())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2)

In [43]:
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature

    def forward(self, zis, zjs):
        """
        zis, zjs: [batch_size, embedding_dim]
        """
        batch_size = zis.size(0)

        # Concatenate and normalize
        z = torch.cat([zis, zjs], dim=0)
        z = F.normalize(z, dim=1)

        # Cosine similarity matrix
        sim = torch.matmul(z, z.T)  # [2*batch_size, 2*batch_size]
        sim = sim / self.temperature

        # Positive pairs
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positives = torch.cat([sim_i_j, sim_j_i], dim=0)  # [2*batch_size]

        # Mask to remove self-comparisons
        N = 2 * batch_size
        mask = (~torch.eye(N, N, dtype=bool)).to(z.device)

        # Get negatives
        negatives = sim[mask].reshape(N, -1)  # [2*batch_size, 2*batch_size - 1]

        # Logits: positive in first column
        logits = torch.cat([positives.unsqueeze(1), negatives], dim=1)  # [2*batch_size, 1 + neg_count]
        labels = torch.zeros(N, dtype=torch.long).to(z.device)  # Always index 0 is the positive

        loss = F.cross_entropy(logits, labels)
        return loss


In [44]:
model = models.resnet18(pretrained=False).to(device)
model.fc = nn.Identity()  # Remove final classification layer

optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.9, weight_decay=1e-4)
batch_size = 256
criterion = NTXentLoss(temperature=0.5)


In [45]:
for epoch in range(10):  # Adjust epochs as needed
    total_loss = 0
    for (x1, x2), _ in train_loader:
        x1, x2 = x1.to(device), x2.to(device)
        z1 = model(x1)
        z2 = model(x2)
    
        loss = criterion(z1, z2)
    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch [{epoch+1}/10], Loss: {total_loss/len(train_loader):.4f}")

Epoch [1/10], Loss: 5.5659
Epoch [2/10], Loss: 5.3917
Epoch [3/10], Loss: 5.3353
Epoch [4/10], Loss: 5.2963
Epoch [5/10], Loss: 5.2633
Epoch [6/10], Loss: 5.2472
Epoch [7/10], Loss: 5.2268
Epoch [8/10], Loss: 5.2132
Epoch [9/10], Loss: 5.1972
Epoch [10/10], Loss: 5.1871


In [46]:
model.eval()
for param in model.parameters():
    param.requires_grad = False


In [47]:
# Use standard transforms (NO SimCLRTransform)
test_transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform=test_transform)
test_dataset = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=test_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)


In [48]:
classifier = nn.Linear(512, 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-3)


In [49]:
model.eval()  # Encoder stays frozen
classifier.train()

for epoch in range(10):
    total_loss = 0
    correct = 0
    total = 0
    for x, labels in tqdm(train_loader):
        x = x.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            embeddings = model(x)

        outputs = classifier(embeddings)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    acc = 100. * correct / total
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {acc:.2f}%")


100%|██████████| 196/196 [00:04<00:00, 44.26it/s]


Epoch 1, Loss: 1.7498, Accuracy: 36.33%


100%|██████████| 196/196 [00:04<00:00, 42.83it/s]


Epoch 2, Loss: 1.6848, Accuracy: 38.88%


100%|██████████| 196/196 [00:04<00:00, 42.36it/s]


Epoch 3, Loss: 1.6706, Accuracy: 39.28%


100%|██████████| 196/196 [00:04<00:00, 42.53it/s]


Epoch 4, Loss: 1.6544, Accuracy: 40.21%


100%|██████████| 196/196 [00:04<00:00, 45.78it/s]


Epoch 5, Loss: 1.6471, Accuracy: 40.29%


100%|██████████| 196/196 [00:04<00:00, 45.02it/s]


Epoch 6, Loss: 1.6392, Accuracy: 40.74%


100%|██████████| 196/196 [00:04<00:00, 44.18it/s]


Epoch 7, Loss: 1.6343, Accuracy: 40.88%


100%|██████████| 196/196 [00:04<00:00, 44.30it/s]


Epoch 8, Loss: 1.6303, Accuracy: 40.94%


100%|██████████| 196/196 [00:04<00:00, 43.83it/s]


Epoch 9, Loss: 1.6218, Accuracy: 41.28%


100%|██████████| 196/196 [00:04<00:00, 41.15it/s]

Epoch 10, Loss: 1.6202, Accuracy: 41.51%





In [50]:
classifier.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, labels in tqdm(test_loader):
        x = x.to(device)
        labels = labels.to(device)

        embeddings = model(x)
        outputs = classifier(embeddings)

        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

accuracy = 100. * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

100%|██████████| 40/40 [00:01<00:00, 37.12it/s]

Test Accuracy: 41.18%



