In [1]:
pip install torch torchvision

Note: you may need to restart the kernel to use updated packages.


In [1]:
import pandas as pd
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.models import googlenet
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

# Step 1: Select the backbone network to be trained with contrastive learning.
# ``weights="DEFAULT"`` loads torchvision's recommended ImageNet weights for
# GoogLeNet; passing ``weights=True`` only works through a deprecated legacy
# code path that emits a warning and may be removed in future releases.
style_model = googlenet(weights="DEFAULT")  # Pretrained GoogLeNet model

# Step 2: Choose a dataset
# Assuming you have downloaded and extracted the TinyImageNet dataset

# Step 3: Define a custom dataset for style transfer
class StyleTransferDataset(Dataset):
    """Images from an ``ImageFolder`` directory, with the class label dropped by default.

    Intended for the self-supervised (contrastive) phase, where labels are not
    needed.

    Args:
        transform: Optional callable applied to each loaded image. Its return
            value is what ``__getitem__`` yields, so it may be a tuple (e.g. two
            independently augmented views of the same image).
        root: Directory handed to ``torchvision.datasets.ImageFolder``. That
            class expects one sub-directory per class. Defaults to
            ``DEFAULT_ROOT``.
        return_labels: When True, ``__getitem__`` returns ``(image, label)``
            instead of only the image.
    """

    # The original hard-coded location, kept as the default so existing calls
    # behave exactly as before. A raw string avoids doubling every backslash.
    # NOTE(review): this looks like a single TinyImageNet class directory; confirm
    # it really contains one sub-folder per class as ImageFolder requires.
    DEFAULT_ROOT = r'C:\Users\90538\Desktop\data\tiny-imagenet-200\train_try\n01443537'

    def __init__(self, transform=None, root=None, return_labels=False):
        self.return_labels = return_labels
        # The wrapped labelled dataset; indexing it yields (image, class_index).
        self.data = torchvision.datasets.ImageFolder(
            self.DEFAULT_ROOT if root is None else root, transform=transform
        )

    def __getitem__(self, index):
        """Return the transformed image at ``index`` (plus its label if requested)."""
        img, label = self.data[index]
        if self.return_labels:
            return img, label
        return img

    def __len__(self):
        """Return the number of images found under the root directory."""
        return len(self.data)

# Step 4: Define contrastive learning method
def contrastive_loss(image1, image2, temperature=0.5):
    """Symmetric InfoNCE (NT-Xent style) loss between two batches of embeddings.

    Row ``i`` of ``image1`` and row ``i`` of ``image2`` form a positive pair;
    every other row of the opposite batch acts as a negative.

    Args:
        image1: Tensor of shape ``(N, ...)``. Trailing dimensions are flattened
            into one feature vector per sample, so both raw images and model
            embeddings are accepted.
        image2: Tensor with the same batch size ``N`` and the same number of
            features per sample as ``image1``.
        temperature: Positive value dividing the cosine similarities; smaller
            values sharpen the softmax over candidates.

    Returns:
        Scalar tensor: the mean cross-entropy of picking each sample's positive
        among all candidates, averaged over both matching directions
        (``image1 -> image2`` and ``image2 -> image1``). Gradients flow back to
        whatever produced the inputs; the inputs themselves are not modified.

    Raises:
        ValueError: if the batch sizes differ, the batch is empty, or
            ``temperature`` is not positive.
    """
    if image1.shape[0] != image2.shape[0]:
        raise ValueError(
            f"batch sizes differ: {image1.shape[0]} vs {image2.shape[0]}"
        )
    if image1.shape[0] == 0:
        raise ValueError("contrastive_loss needs a non-empty batch")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # One unit-length feature vector per sample, so a dot product is a cosine.
    z1 = F.normalize(image1.reshape(image1.shape[0], -1), dim=1)
    z2 = F.normalize(image2.reshape(image2.shape[0], -1), dim=1)

    # logits[i, j] = cos(z1_i, z2_j) / T; the positive pairs sit on the diagonal.
    logits = z1 @ z2.T / temperature
    targets = torch.arange(logits.shape[0], device=logits.device)

    # Row-wise cross-entropy is -log(exp(logits[i, i]) / sum_j exp(logits[i, j])),
    # evaluated stably via log-sum-exp inside F.cross_entropy.
    loss_1to2 = F.cross_entropy(logits, targets)
    loss_2to1 = F.cross_entropy(logits.T, targets)
    return 0.5 * (loss_1to2 + loss_2to1)


# Step 5: Define a linear classifier on top of the trained model
class LinearClassifier(nn.Module):
    """A single affine layer mapping feature vectors to class logits.

    Args:
        input_size: Number of features in each input vector.
        num_classes: Number of output logits produced per input.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        # Registered as ``fc`` so the state_dict keys are ``fc.weight``/``fc.bias``.
        self.fc = nn.Linear(in_features=input_size, out_features=num_classes)

    def forward(self, x):
        """Return the raw (unnormalised) logits ``fc(x)``."""
        return self.fc(x)

# Step 6: Prepare the data
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing, used for the linear-probe phase.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # ImageNet normalization
])

# Random augmentation for the contrastive phase. Two independent draws of this
# pipeline on the same image form a positive pair.
contrastive_augment = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


def two_views(img):
    """Return two independently augmented views of the same image."""
    return contrastive_augment(img), contrastive_augment(img)


# Each item is a (view1, view2) tuple; the default collate turns a batch of
# them into two batched tensors.
dataset = StyleTransferDataset(transform=two_views)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Step 7: Train the backbone with contrastive learning
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Use GoogLeNet's pooled embedding (its fc layer is Linear(1024, 1000)) as the
# representation instead of the 1000 ImageNet logits. This is also the input
# size the linear classifier below expects.
FEATURE_DIM = 1024
style_model.fc = nn.Identity()
style_model.to(device)

contrastive_optimizer = optim.Adam(style_model.parameters(), lr=0.001)

for epoch in range(10):
    # BatchNorm/Dropout must be in training mode while the backbone is optimised.
    style_model.train()
    for view1, view2 in dataloader:
        view1 = view1.to(device)
        view2 = view2.to(device)

        # No torch.no_grad() here: the loss must back-propagate into the backbone.
        z1 = style_model(view1)
        z2 = style_model(view2)

        # Row i of z1 and row i of z2 come from the same image (the positive pair).
        contrastive_loss_value = contrastive_loss(z1, z2)

        # Optimize the contrastive model
        contrastive_optimizer.zero_grad()
        contrastive_loss_value.backward()
        contrastive_optimizer.step()

# Step 8: Freeze the trained model and train a linear classifier on top
contrastive_model = style_model
contrastive_model.eval()
for param in contrastive_model.parameters():
    param.requires_grad_(False)

# The probe needs labels, so use the wrapped ImageFolder directly: it yields
# (image, class_index) pairs, whereas StyleTransferDataset returns only images.
probe_dataset = StyleTransferDataset(transform=transform).data
probe_loader = DataLoader(probe_dataset, batch_size=64, shuffle=True)

# Derive the class count from the folder layout instead of hard-coding 200.
# NOTE(review): if the dataset root holds a single class, the probe is trivial.
num_classes = len(probe_dataset.classes)
linear_classifier = LinearClassifier(FEATURE_DIM, num_classes=num_classes)
linear_classifier.to(device)

linear_classifier_optimizer = optim.Adam(linear_classifier.parameters(), lr=0.001)

for epoch in range(10):
    linear_classifier.train()
    for images, labels in probe_loader:
        images = images.to(device)
        labels = labels.to(device)

        # The backbone is frozen; only the linear layer receives gradients.
        with torch.no_grad():
            features = contrastive_model(images)
        features = F.normalize(features, dim=1)

        # Train the linear classifier
        linear_classifier_optimizer.zero_grad()
        outputs = linear_classifier(features)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        linear_classifier_optimizer.step()

# Step 9: Calculate the performance of the model
# You can use a separate validation dataset or perform cross-validation to evaluate the performance of the model.

# Step 10: Compare the results with the literature
