In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
transform = transforms.Compose([
    transforms.Resize((10,10)),
    transforms.ToTensor()
])

In [3]:
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
cifar10_train[0][0].shape

torch.Size([3, 10, 10])

In [5]:
class CPDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]
        img = img.permute(1, 2, 0).numpy() * 255 
        img = img.astype(np.uint8)  

        patches = []
        for i in range(3):
            for j in range(3):
                patch = img[10*i:10*(i+1), 10*j:10*(j+1), :]
                patch = Image.fromarray(patch)  
                if self.transform:
                    patch = self.transform(patch)
                patches.append(patch)

        idx1, idx2 = np.random.choice(9, 2, replace=False)
        label = self.get_spatial_label(idx1, idx2)

        return patches[idx1], patches[idx2], label

    def get_spatial_label(self, idx1, idx2):
        rel_positions = {
            (-1, -1): 0, (-1, 0): 1, (-1, 1): 2,
            (0, -1): 3, (0, 1): 4,
            (1, -1): 5, (1, 0): 6, (1, 1): 7
        }
        row1, col1 = divmod(idx1, 3)
        row2, col2 = divmod(idx2, 3)
        
        return rel_positions.get((row2 - row1, col2 - col1), -1)

    def __len__(self):
        return len(self.dataset)

In [6]:
train_dataset = CPDataset(cifar10_train, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Model

In [12]:
class ContextPrediction(nn.Module):
    def __init__(self, num_classes=8):
        super().__init__()

        self.shared_cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), # 10x10
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2,2)),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), # 4x4
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2,2))
        )

        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, patch1, patch2):
        x1 = self.shared_cnn(patch1)
        x1 = x1.view(x1.size(0), -1)
        x2 = self.shared_cnn(patch2)
        x2 = x2.view(x2.size(0), -1)
        x = torch.cat((x1, x2), dim=1)
        x = self.fc(x)
        return x

In [13]:
import torch.optim as optim

In [14]:
device = 'cuda:3'

model = ContextPrediction().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # ❌ -1 레이블을 무시

In [19]:
def train():
    model.train()
    for epoch in range(20):
        total_loss = 0
        for patch1, patch2, labels in train_loader:
            patch1, patch2, labels = patch1.to(device), patch2.to(device), labels.to(device)
            optimizer.zero_grad()
            
            output = model(patch1, patch2)
            loss = criterion(output, labels)
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        print(f"Epoch [{epoch+1}/{20}], Loss: {total_loss/len(train_loader):.4f}")            

In [20]:
train()

Epoch [1/20], Loss: 1.9208
Epoch [2/20], Loss: 1.9196
Epoch [3/20], Loss: 1.9146
Epoch [4/20], Loss: 1.9212
Epoch [5/20], Loss: 1.9208
Epoch [6/20], Loss: 1.9171
Epoch [7/20], Loss: 1.9165
Epoch [8/20], Loss: 1.9223
Epoch [9/20], Loss: 1.9138
Epoch [10/20], Loss: 1.9132
Epoch [11/20], Loss: 1.9222
Epoch [12/20], Loss: 1.9149
Epoch [13/20], Loss: 1.9197
Epoch [14/20], Loss: 1.9171
Epoch [15/20], Loss: 1.9145
Epoch [16/20], Loss: 1.9166
Epoch [17/20], Loss: 1.9149
Epoch [18/20], Loss: 1.9153
Epoch [19/20], Loss: 1.9148
Epoch [20/20], Loss: 1.9196
