In [None]:
# 2.4 Visualizing the STL-10 Dataset Step-by-Step
# First, let's load and visualize STL-10:

import numpy as np

import torchvision

import matplotlib.pyplot as plt

from torchvision.datasets import STL10



# Load the labelled STL-10 training split. The first run downloads the full
# archive (~2.6 GB); later runs reuse what is already under ./data.
dataset = STL10(
    download=True,
    split="train",
    root="./data",
)


 55%|█████▌    | 1.46G/2.64G [17:20<15:42, 1.25MB/s]

In [None]:
# Now, let's display a 2x5 grid of sample images with their class names:

fig, axes = plt.subplots(2, 5, figsize=(10, 5))

# Pick distinct random indices; the fixed seed keeps the grid reproducible
# under Restart & Run All.
sample_rng = np.random.default_rng(0)
sample_indices = sample_rng.choice(len(dataset), size=axes.size, replace=False)

for ax, idx in zip(axes.flat, sample_indices):
    img, label = dataset[int(idx)]

    # With no transform, STL10 returns a PIL image, and np.asarray of it is
    # already channels-last (H, W, C) -- exactly what imshow expects. Do not
    # transpose: .transpose(1, 2, 0) would give an invalid (96, 3, 96) array.
    ax.imshow(np.asarray(img))

    ax.set_title(f"Class: {dataset.classes[label]}")
    ax.axis("off")

plt.tight_layout()
plt.show()


In [None]:
# To better understand the dataset, we visualize sample images:

import numpy as np
import torchvision
import matplotlib.pyplot as plt
from torchvision.datasets import STL10

dataset = STL10(root="./data", split="train", download=True)
fig, axes = plt.subplots(2, 5, figsize=(10, 5))

# Distinct random samples, seeded so the figure is reproducible on re-run.
sample_rng = np.random.default_rng(0)
sample_indices = sample_rng.choice(len(dataset), size=axes.size, replace=False)

for ax, idx in zip(axes.flat, sample_indices):
    img, label = dataset[int(idx)]
    # STL10 without a transform yields a PIL image; np.asarray gives (H, W, C),
    # which imshow accepts directly (no transpose needed).
    ax.imshow(np.asarray(img))
    ax.set_title(f"Class: {dataset.classes[label]}")
    ax.axis("off")
plt.tight_layout()
plt.show()


In [None]:
# Example Code: Data Augmentation for Contrastive Learning

from torchvision import transforms

# Define transformations used in SimCLR
# SimCLR's colour distortion (Chen et al., 2020, Appendix A) is
# ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) wrapped in RandomApply(p=0.8).
# Brightness/contrast/saturation of 0.4 correspond to strength s = 0.5, so the
# matching hue jitter is 0.1 (a hue of 0.4 exceeds even the s = 1 setting), and
# the jitter is applied to 80% of images rather than to every image.
simclr_transforms = transforms.Compose([
    transforms.RandomResizedCrop(96),  # crop + resize back to STL-10's 96x96
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])


In [None]:
# Step 1: Load Data with Augmentations

from torch.utils.data import DataLoader

class TwoViewTransform:
    """Apply a stochastic transform twice to the same image.

    Contrastive learning needs two independently augmented views of each image
    (a positive pair). Wrapping the augmentation pipeline with this class makes
    each dataset sample ``((view_1, view_2), label)`` instead of ``(view, label)``.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, img):
        # Two separate calls draw two independent sets of random augmentation
        # parameters, so the views differ.
        return self.base_transform(img), self.base_transform(img)


# NOTE: split="unlabeled" or "train+unlabeled" is the usual choice for
# self-supervised pretraining on STL-10; this tutorial keeps the smaller
# labelled "train" split.
train_loader = DataLoader(
    STL10(
        root="./data",
        split="train",
        transform=TwoViewTransform(simclr_transforms),
        download=True,
    ),
    batch_size=256,
    shuffle=True,
    # Keep every batch at exactly 256 images so the contrastive similarity
    # matrix (and any precomputed mask) has a fixed size.
    drop_last=True,
)

In [None]:
# Step 2: Define a Feature Extractor (ResNet-18)

import torch
import torchvision.models as models


class SimCLRFeatureExtractor(torch.nn.Module):
    """ResNet-18 backbone with its classification head removed.

    ``fc`` is replaced by ``Identity``, so ``forward`` returns the pooled
    feature vector that feeds the classifier in a stock ResNet-18.

    Args:
        pretrained: Load ImageNet weights (the original behaviour). Pass
            ``False`` to start from random initialisation, as SimCLR
            pretraining normally does.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # NOTE(review): newer torchvision versions deprecate ``pretrained=`` in
        # favour of ``weights=``; kept here for compatibility with older releases.
        self.encoder = models.resnet18(pretrained=pretrained)
        self.encoder.fc = torch.nn.Identity()  # Remove classification head

    def forward(self, x):
        """Encode an image batch ``x`` into backbone features."""
        return self.encoder(x)


In [None]:
# Step 3: Implementing the Projection Head

import torch


class ProjectionHead(torch.nn.Module):
    """Two-layer MLP that maps encoder features into the contrastive space.

    Computes ``fc2(relu(fc1(x)))``.

    Args:
        input_dim: Size of each incoming feature vector (default 512, matching
            the ResNet-18 extractor with its ``fc`` removed).
        output_dim: Size of the projection used by the contrastive loss.
        hidden_dim: Width of the hidden layer. Defaults to 256 (the original
            value); set it to ``input_dim`` to mirror the SimCLR paper's head.
    """

    def __init__(self, input_dim=512, output_dim=128, hidden_dim=256):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x`` of shape (batch, input_dim) to (batch, output_dim)."""
        x = torch.nn.functional.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
from torchvision import transforms

# SimCLR-style augmentations. Colour distortion follows the paper's recipe
# ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) inside RandomApply(p=0.8), here with
# strength s = 0.5 -> (0.4, 0.4, 0.4, 0.1).
simclr_transforms = transforms.Compose([
    transforms.RandomResizedCrop(96),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])

# These strong augmentations are applied twice to the same image; the two differently augmented views form a positive pair for contrastive learning.