# Contrastive learning을 활용한 SSL(Self-Supervised Learning)

아래 세 그룹을 비교하여 SSL의 성능을 확인합니다.

1. Full data를 활용한 지도학습 (학습 데이터의 80%인 `train_subset` 사용)
2. 데이터를 일부만 활용한 지도학습 (나머지 20%인 `val_subset`만 사용)
3. Contrastive learning으로 Pre-training(라벨 없이 학습 데이터 전체 사용) + 데이터의 일부(`val_subset`)만 추출한 뒤 Fine-tuning

## Prep

In [None]:
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import transforms
import torchvision.models as models

# Pick the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    _backend = "cuda"
elif torch.backends.mps.is_available():
    _backend = "mps"
else:
    _backend = "cpu"

device = torch.device(_backend)
print(f"{device} ready!")

### 네트워크 설계

In [2]:
class SimCLRNet(nn.Module):
    """Small CNN encoder topped with either a projection head or a classifier.

    The encoder consists of three convolutional stages (two 3x3 convolutions
    followed by a 2x2 max-pool each) and a two-layer MLP producing a
    256-dimensional feature vector. The ``nn.Linear(256 * 4 * 4, ...)`` layer
    fixes the expected input to 3-channel 32x32 images (32 -> 16 -> 8 -> 4
    after the three pooling layers).

    Args:
        projection_dim: Output size of the projection head (only used when
            ``cls`` is falsy).
        cls: If truthy, attach a linear classifier on top of the encoder;
            otherwise attach the projection head.
        num_classes: Number of logits produced by the classifier (only used
            when ``cls`` is truthy).
    """

    def __init__(self, projection_dim=128, cls=False, num_classes=10):
        super().__init__()
        self.cls = cls

        # Encoder. The layer order is kept flat so that state-dict keys
        # ("encoder.0.weight", ...) are shared by both model variants.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
        )

        if self.cls:
            # Classification head producing one logit per class.
            self.classifier = nn.Sequential(
                nn.Linear(256, num_classes)
            )
        else:
            # Projection head used for the contrastive objective.
            self.projection_head = nn.Sequential(
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, projection_dim),
            )

    def forward(self, x):
        """Return class logits when ``cls`` is truthy, otherwise projections."""
        features = self.encoder(x)
        if self.cls:
            return self.classifier(features)
        return self.projection_head(features)

In [None]:
# Model for contrastive pre-training: cls=False attaches the projection head.
simclrnet = SimCLRNet(cls=False)
simclrnet  # bare expression so the notebook displays the module structure

In [4]:
# Two independently initialised classifiers (cls=True) used as supervised baselines.
cnn_1 = SimCLRNet(cls=True) # Full data Supervised learning
cnn_2 = SimCLRNet(cls=True) # Partial data Supervised learning

In [None]:
# Sanity check: push one random 3x32x32 input through both model variants
# and show the shape each head produces.
dummy_input = torch.randn(1, 3, 32, 32)
dummy_output_simclr, dummy_output_cnn = simclrnet(dummy_input), cnn_1(dummy_input)

print(f"{dummy_output_simclr.shape} {dummy_output_cnn.shape}")

In [None]:
# Baseline preprocessing: convert to a tensor, then normalise every channel
# with mean 0.5 and std 0.5.
transform_cnn = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load the CIFAR-10 train and test splits (downloaded into ./data if missing).
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform_cnn)
    for is_train in (True, False)
)

# Randomly split the training set 80 / 20.
total_length = len(train_dataset)
train_length = int(total_length * 0.8)
train_subset, val_subset = random_split(
    train_dataset,
    [train_length, total_length - train_length],
)

# Only the training subset is shuffled.
train_loader = DataLoader(train_subset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# CIFAR-10 class names, indexed by label.
classes = 'plane car bird cat deer dog frog horse ship truck'.split()

In [None]:
# Report the number of samples in the 80% / 20% training split and the test set.
print(f"Train size: {len(train_subset)}\nVal size: {len(val_subset)}\nTest size: {len(test_dataset)}")

# 비교 실험

## Vanilla CNN - Full data supervised learning

In [None]:
# Loss and optimizer for the full-data baseline (cnn_1):
# cross-entropy with Adam (lr 3e-4, weight decay 1e-3).
CE_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_1.parameters(), lr=3e-4, weight_decay=0.001)
cnn_1.to(device)  # moves the model; as the last expression it is also displayed by the notebook

In [9]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch (with a progress bar); otherwise the model is put in
    evaluation mode and gradients are disabled.
    """
    is_train = optimizer is not None
    model.train(is_train)

    loss_sum = 0.0
    correct = 0
    total = 0
    batches = tqdm(loader) if is_train else loader

    with torch.set_grad_enabled(is_train):
        for x, y in batches:
            x = x.to(device)
            y = y.to(device)

            if is_train:
                optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            if is_train:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by the batch size so the epoch value is a
            # per-sample mean even when the last batch is smaller.
            loss_sum += loss.item() * x.size(0)
            total += y.size(0)
            correct += (pred.argmax(1) == y).sum().item()

    return loss_sum / total, correct / total


def supervised_train_validate(model, train_loader, val_loader, optimizer, criterion, epochs, device):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: Module mapping an input batch to per-class scores.
        train_loader: Iterable of ``(x, y)`` batches used for optimisation.
        val_loader: Iterable of ``(x, y)`` batches used for evaluation only.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as ``criterion(pred, y)``.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.

    Returns:
        Four lists with one entry per epoch:
        ``(train_losses, train_accs, val_losses, val_accs)``.
    """
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    model.to(device)

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        print(f"Epoch {epoch+1}\nTrain Loss: {train_loss:.4f}\tTrain Accuracy: {train_acc:.4f}")

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        print(f"Val Loss: {val_loss:.4f}\tVal Accuracy: {val_acc:.4f}")

    return train_losses, train_accs, val_losses, val_accs

In [None]:
# Full-data baseline: train cnn_1 on train_loader for 10 epochs.
# NOTE: test_loader is passed as the validation loader, so the "val" metrics
# returned here are measured on the CIFAR-10 test set.
train_losses, train_accs, val_losses, val_accs = supervised_train_validate(cnn_1, train_loader, test_loader, optimizer, CE_loss, 10, device)

In [None]:
# One figure per metric (loss first, then accuracy), each overlaying the
# training curve and the evaluation curve of the full-data baseline.
for train_curve, val_curve in ((train_losses, val_losses), (train_accs, val_accs)):
    plt.plot(train_curve, label='train')
    plt.plot(val_curve, label='val')
    plt.legend()
    plt.show()

## Supervised learning with partial data only

In [None]:
# Loss and optimizer for the partial-data baseline (cnn_2), using the same
# hyper-parameters as the full-data run: Adam with lr 3e-4, weight decay 1e-3.
CE_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_2.parameters(), lr=3e-4, weight_decay=0.001)
cnn_2.to(device)  # moves the model; as the last expression it is also displayed by the notebook

In [None]:
# Partial-data baseline: val_loader (the 20% split) serves as the TRAINING set
# here, and test_loader is again used for the per-epoch evaluation.
part_train_losses, part_train_accs, part_val_losses, part_val_accs = supervised_train_validate(cnn_2, val_loader, test_loader, optimizer, CE_loss, 10, device)

In [None]:
# One figure per metric (loss first, then accuracy) for the partial-data baseline.
for train_curve, val_curve in (
    (part_train_losses, part_val_losses),
    (part_train_accs, part_val_accs),
):
    plt.plot(train_curve, label='train')
    plt.plot(val_curve, label='val')
    plt.legend()
    plt.show()

## Pre-training & Finetuning

### Pre-training with Contrastive learning

In [14]:
# Stochastic augmentation pipeline used to create the two views of each image:
# random crop resized to 32x32, horizontal flip, colour jitter (applied with
# probability 0.8), random grayscale, then the same tensor conversion and
# normalisation (mean 0.5, std 0.5 per channel).
transform_simclr = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

class CIFAR10SimCLR(torchvision.datasets.CIFAR10):
    """CIFAR-10 variant that returns two transformed views of every image.

    ``__getitem__`` applies ``self.transform`` twice to the same image, so a
    stochastic transform produces two different augmentations that form a
    positive pair for contrastive learning.
    """

    def __getitem__(self, index):
        """Return ``(view_1, view_2, target)`` for the sample at ``index``.

        If no ``transform`` was configured, both views are untransformed PIL
        images. ``target_transform`` is applied to the label when configured.
        """
        img, target = self.data[index], self.targets[index]
        # Presumably self.data holds image arrays as stored by the torchvision
        # base class; convert to PIL so the torchvision transforms can be applied.
        img = Image.fromarray(img)

        if self.transform is not None:
            # Two independent calls -> two different random augmentations.
            x_1 = self.transform(img)
            x_2 = self.transform(img)
        else:
            x_1, x_2 = img, img.copy()

        if self.target_transform is not None:
            target = self.target_transform(target)

        return x_1, x_2, target

In [None]:
# Two-view dataset over the whole CIFAR-10 training split (train=True), used for
# contrastive pre-training; the pre-training loop ignores the returned labels.
simclr_dataset = CIFAR10SimCLR(root='./data', train=True, download=True, transform=transform_simclr)

simclr_loader = DataLoader(simclr_dataset, batch_size=128, shuffle=True)

In [None]:
# Show the two augmented views of the first ten samples side by side.
for i in range(10):
    x1, x2, y = simclr_dataset[i]
    print(x1.shape, x2.shape, y)

    fig, axes = plt.subplots(1, 2, figsize=(6, 3))

    # CHW -> HWC and undo the (0.5, 0.5) normalisation for display.
    x1, x2 = ((view.permute(1, 2, 0) * 0.5 + 0.5).numpy() for view in (x1, x2))

    for ax, view in zip(axes, (x1, x2)):
        ax.imshow(view)
        ax.axis('off')
        ax.set_title(classes[y])

    plt.show()  # render the figure

In [17]:
# NT-Xent (normalized temperature-scaled cross-entropy) loss
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Compute the NT-Xent contrastive loss for a batch of paired embeddings.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are treated as a positive
    pair; every other embedding in the concatenated batch acts as a negative.
    Each of the ``2 * batch_size`` embeddings is used as an anchor with its own
    softmax normalisation over all other embeddings, and the loss is the mean
    over all anchors.

    Args:
        z_i: Tensor of shape ``(batch_size, dim)``, embeddings of the first view.
        z_j: Tensor of the same shape, embeddings of the second view.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor holding the mean loss.
    """
    batch_size = z_i.shape[0]

    # L2-normalise so that a plain matrix product yields cosine similarities
    # without materialising a (2B, 2B, dim) broadcast intermediate.
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    logits = (z @ z.t()) / temperature

    # An embedding must not be compared with itself: masking the diagonal with
    # -inf removes it from the softmax denominator.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The positive of anchor k is k + B for the first half and k - B for the second.
    targets = (torch.arange(2 * batch_size, device=z.device) + batch_size) % (2 * batch_size)

    # cross_entropy == mean over anchors of -log(exp(pos) / sum_{m != k} exp(sim)),
    # evaluated with a numerically stable log-sum-exp.
    return F.cross_entropy(logits, targets)

In [18]:
def pretrain_simclr(model, dataloader, optimizer, epochs=10, temperature=0.5):
    """Pre-train ``model`` with the NT-Xent contrastive objective.

    Each batch from ``dataloader`` is expected to be a triple
    ``(view_1, view_2, label)``; the label is ignored. The model is moved to
    the module-level ``device``, and the mean batch loss is printed after
    every epoch.

    Returns:
        The same ``model`` object, after training.
    """
    model.to(device)
    model.train()

    for epoch_idx in range(epochs):
        running_loss = 0.0

        for view_a, view_b, _ in tqdm(dataloader):
            # Clear the gradients left over from the previous step.
            optimizer.zero_grad()

            # Embed both augmented views with the same network.
            emb_a = model(view_a.to(device))
            emb_b = model(view_b.to(device))

            # Contrastive loss followed by a parameter update.
            batch_loss = nt_xent_loss(emb_a, emb_b, temperature)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        print(f"Epoch [{epoch_idx+1}/{epochs}] Loss: {running_loss/len(dataloader):.4f}")

    return model

In [19]:
# Optimizer for contrastive pre-training: Adam with lr 1e-3 over all simclrnet
# parameters (encoder and projection head).
simclr_optim = optim.Adam(simclrnet.parameters(), lr=0.001)

In [None]:
# Run 10 epochs of contrastive pre-training. simclrnet is trained in place; the
# returned model is the last expression, so the notebook also displays it.
pretrain_simclr(simclrnet, simclr_loader, simclr_optim, epochs=10, temperature=0.5)

### Fine-tuning

In [21]:
# Build a classifier variant of the network for fine-tuning
finetune_model = SimCLRNet(cls=True).to(device)

# Copy the encoder weights from the pre-trained model (the classifier head
# keeps its fresh initialisation)
finetune_model.encoder.load_state_dict(simclrnet.encoder.state_dict())

# Optimizer and loss; all parameters of finetune_model (encoder included) are
# passed to Adam, so the encoder is updated during fine-tuning as well
finetune_optim = optim.Adam(finetune_model.parameters(), lr=3e-4, weight_decay=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
# Fine-tune on the same 20% split (val_loader) that the partial-data baseline
# trained on, and evaluate on test_loader after every epoch.
simclr_train_losses, simclr_train_accs, simclr_val_losses, simclr_val_accs = supervised_train_validate(finetune_model, val_loader, test_loader, finetune_optim, criterion, epochs=10, device=device)

In [None]:
# One figure per metric (loss first, then accuracy) for the fine-tuned model.
for train_curve, val_curve in (
    (simclr_train_losses, simclr_val_losses),
    (simclr_train_accs, simclr_val_accs),
):
    plt.plot(train_curve, label='simclr_train')
    plt.plot(val_curve, label='simclr_val')
    plt.legend()
    plt.show()

## 결과 비교

In [None]:
# Overlay the evaluation curves of the three experiments, one figure per metric.
comparison_figures = (
    ('Validation Loss comparison', (val_losses, part_val_losses, simclr_val_losses)),
    ('Validation Accuracy comparison', (val_accs, part_val_accs, simclr_val_accs)),
)
experiment_labels = (
    'Full data Supervised Learning',
    'Partial data Supervised Learning',
    'SimCLR',
)

for figure_title, curves in comparison_figures:
    plt.title(figure_title)
    for curve, experiment_label in zip(curves, experiment_labels):
        plt.plot(curve, label=experiment_label)
    plt.legend()
    plt.show()