In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets
from transformers import ViTModel
from torch.utils.data import DataLoader

# 1. SimCLR projection head
class SimCLRHead(nn.Module):
    """Two-layer MLP projection head used by SimCLR.

    Maps features through Linear -> ReLU -> Linear to ``proj_dim`` and then
    L2-normalizes the result along dim 1.

    Args:
        in_dim: size of the incoming feature vectors (also used as hidden width).
        proj_dim: size of the projected embedding.
    """

    def __init__(self, in_dim, proj_dim=128):
        super().__init__()
        layers = [
            nn.Linear(in_dim, in_dim),  # hidden layer keeps the input width
            nn.ReLU(),
            nn.Linear(in_dim, proj_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` and L2-normalize along dim 1 (one embedding per row for 2-D input)."""
        projected = self.net(x)
        return F.normalize(projected, dim=1)

# 2. SimCLR model with Hugging Face ViT
class SimCLRViT(nn.Module):
    """SimCLR network: a Hugging Face ViT backbone followed by a ``SimCLRHead``.

    Args:
        vit_name: model id or path handed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, vit_name='google/vit-base-patch16-224'):
        super().__init__()
        backbone = ViTModel.from_pretrained(vit_name)
        self.vit = backbone
        self.projector = SimCLRHead(backbone.config.hidden_size)

    def forward(self, x):
        """Encode the pixel batch ``x`` and project the token at position 0 (CLS)."""
        hidden = self.vit(pixel_values=x).last_hidden_state
        cls_token = hidden[:, 0]
        return self.projector(cls_token)

# 3. NT-Xent contrastive loss
def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR.

    Each of the 2N embeddings acts as an anchor whose positive is the other
    view of the same sample (row i of ``z1`` <-> row i of ``z2``); the other
    2N - 2 embeddings in the batch are its negatives.

    Args:
        z1: (N, D) embeddings of the first augmented view.
        z2: (N, D) embeddings of the second view, row-aligned with ``z1``.
        temperature: positive scale applied to the cosine similarities.

    Returns:
        Scalar tensor: mean loss over all 2N anchors.

    Raises:
        ValueError: if ``z1`` and ``z2`` differ in shape or ``temperature`` <= 0.
    """
    if z1.shape != z2.shape:
        raise ValueError(
            f"z1 and z2 must have the same shape, got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    # Pairwise cosine similarity via one (2N, D) @ (D, 2N) matmul; avoids the
    # (2N, 2N, D) intermediate that a broadcast cosine_similarity allocates.
    logits = (z @ z.T) / temperature
    # An embedding is never its own negative: drop the diagonal from the softmax.
    # -inf (rather than a huge finite constant) is representable in every float dtype.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row i (first view) pairs with row i + N (second view) and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)]).to(z.device)
    # cross_entropy evaluates -log softmax with log-sum-exp, so it stays finite
    # for small temperatures where exp(sim / T) would overflow.
    return F.cross_entropy(logits, targets)

# 4. SimCLR augmentations (two views per image)
simclr_transform = T.Compose(
    [
        # Geometry: upscale to 224, then take a random resized crop back to 224x224.
        T.Resize(224),
        T.RandomResizedCrop(size=224),
        T.RandomHorizontalFlip(),
        # Photometric distortion.
        T.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
        T.RandomGrayscale(p=0.2),
        # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

class SimCLRDataset(torch.utils.data.Dataset):
    """Wraps a labeled dataset and yields two independently transformed views per item.

    The label of the wrapped dataset is discarded.

    Args:
        dataset: indexable dataset whose items unpack as ``(sample, label)``.
        transform: callable applied separately to produce each view.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, _label = self.dataset[index]
        # Two separate calls so a random transform produces two different views.
        first_view = self.transform(sample)
        second_view = self.transform(sample)
        return first_view, second_view

# 5. Load real dataset (e.g., CIFAR-10)
SEED = 0
LOG_EVERY = 50  # print the running-average loss every LOG_EVERY optimizer steps

# Seed before building the loader/model so shuffling, augmentations (torch RNG,
# main-process loading) and newly initialized weights are reproducible.
torch.manual_seed(SEED)

base_dataset = datasets.CIFAR10(root='./data', download=True, train=True)
contrastive_dataset = SimCLRDataset(base_dataset, simclr_transform)
# NT-Xent draws its negatives from the batch, so a short final batch would change
# the number of negatives (and the loss scale); keep every batch full-size.
loader = DataLoader(contrastive_dataset, batch_size=16, shuffle=True, drop_last=True)

# 6. Train one epoch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimCLRViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
running_loss = 0.0
for step, (x1, x2) in enumerate(loader, start=1):
    x1, x2 = x1.to(device), x2.to(device)
    # Both augmented views pass through the same (shared-weight) model.
    z1 = model(x1)
    z2 = model(x2)
    loss = nt_xent_loss(z1, z2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    # Report an average every LOG_EVERY steps instead of one line per step.
    if step % LOG_EVERY == 0:
        print(f"Step {step}/{len(loader)} - Loss: {running_loss / LOG_EVERY:.4f}")
        running_loss = 0.0
     

Files already downloaded and verified


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loss: 3.2618
Loss: 3.2524
Loss: 3.0144
Loss: 3.1063
Loss: 2.9531
Loss: 2.9929
Loss: 2.7704
Loss: 2.9567
Loss: 2.8871
Loss: 2.7814
Loss: 2.6348
Loss: 2.7831
Loss: 2.6347
Loss: 2.5610
Loss: 2.3411
Loss: 3.1918
Loss: 2.6296
Loss: 2.8086
Loss: 2.7367
Loss: 2.5065
Loss: 2.6690
Loss: 2.6929
Loss: 2.9132
Loss: 2.3690
Loss: 2.4448
Loss: 2.6789
Loss: 2.5385
Loss: 2.4750
Loss: 2.5035
Loss: 2.4985
Loss: 2.7757
Loss: 2.5425
Loss: 2.5418
Loss: 2.5142
Loss: 2.5408
Loss: 2.6757
Loss: 2.2696
Loss: 2.2776
Loss: 2.5645
Loss: 2.7233
Loss: 2.5119
Loss: 2.4778
Loss: 2.6486
Loss: 2.3645
Loss: 2.5653
Loss: 2.3759
Loss: 2.3978
Loss: 2.3867
Loss: 2.6243
Loss: 2.3129
Loss: 2.2547
Loss: 2.7386
Loss: 2.3332
Loss: 2.5635
Loss: 2.4412
Loss: 2.7544
Loss: 2.4257
Loss: 2.3533
Loss: 2.5612
Loss: 2.4589
Loss: 2.2187
Loss: 2.3256
Loss: 2.3865
Loss: 2.2964
Loss: 2.6938
Loss: 2.3360
Loss: 2.4444
Loss: 2.5634
Loss: 2.5381
Loss: 2.2305
Loss: 2.3927
Loss: 2.3927
Loss: 2.6843
Loss: 2.7041
Loss: 2.5965
Loss: 2.7713
Loss: 2.3335

In [None]:
# Use the learned model as a feature extractor (embedding) and train a supervised classifier on top of it.
# In real-world settings, an unlabeled dataset is used for the first (self-supervised) stage and a labeled dataset for the classifier.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. Extract only the encoder (without the projection head)
# `model` is the SimCLRViT trained above; `.vit` is its ViT backbone.
# eval() returns the module itself. It only switches layers such as dropout to
# inference mode; it does not by itself stop gradients or weight updates.
encoder = model.vit.eval()

# 2. Classificatore lineare
class LinearClassifier(nn.Module):
    """Linear probe: one fully-connected layer mapping features to class logits.

    Args:
        in_dim: size of the input feature vectors.
        num_classes: number of output logits.
    """

    def __init__(self, in_dim, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x):
        """Return unnormalized class scores for the feature batch ``x``."""
        logits = self.fc(x)
        return logits

classifier = LinearClassifier(in_dim=encoder.config.hidden_size)
# Run everything on the device the (frozen) encoder already lives on, so inputs,
# features and classifier weights never end up on different devices.
probe_device = next(encoder.parameters()).device
classifier = classifier.to(probe_device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 3. Standard (deterministic, augmentation-free) transforms for classification
transform_eval = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# 4. Labeled dataset
train_ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_eval)
test_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_eval)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64)


@torch.no_grad()
def extract_features(data_loader):
    """Encode every image in ``data_loader`` once with the frozen encoder.

    Returns:
        ``(features, labels)`` as CPU tensors: the token-0 (CLS) embeddings of
        the encoder's last hidden state, and the matching labels.
    """
    encoder.eval()
    feature_chunks, label_chunks = [], []
    for x, y in data_loader:
        hidden = encoder(pixel_values=x.to(probe_device)).last_hidden_state
        feature_chunks.append(hidden[:, 0].cpu())  # CLS token
        label_chunks.append(y)
    return torch.cat(feature_chunks), torch.cat(label_chunks)


# The encoder is frozen and in eval mode, and transform_eval is deterministic, so
# the features are identical every epoch: compute them once instead of per epoch.
train_features, train_labels = extract_features(train_loader)
test_features, test_labels = extract_features(test_loader)
feature_loader = DataLoader(
    torch.utils.data.TensorDataset(train_features, train_labels),
    batch_size=64,
    shuffle=True,
)

# 5. Classifier training loop
for epoch in range(5):
    classifier.train()
    total_loss = 0.0
    for features, y in feature_loader:
        features, y = features.to(probe_device), y.to(probe_device)
        logits = classifier(features)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(feature_loader):.4f}")

# 6. Linear-probe accuracy on the held-out test split
classifier.eval()
with torch.no_grad():
    test_logits = classifier(test_features.to(probe_device))
test_accuracy = (test_logits.argmax(dim=1).cpu() == test_labels).float().mean().item()
print(f"Test accuracy: {test_accuracy:.4f}")
