<a href="https://colab.research.google.com/github/hongqin/Python-CoLab-bootcamp/blob/master/contrast_learning_ministipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# generated by perplexity

# Contrastive Learning Tutorial for Beginners
# This tutorial demonstrates a basic implementation of contrastive learning using the MNIST dataset.

# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Preprocessing pipeline applied to every image when it is read from the dataset:
# first convert the image to a tensor, then standardize it with the MNIST
# mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# MNIST training split, downloaded into ./data if it is not already there.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Serve the training set in shuffled mini-batches of 64 images.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)

# Define the neural network architecture
class ContrastiveNet(nn.Module):
    """Small convolutional encoder that outputs unit-length 128-d embeddings.

    Architecture: two 3x3 convolutions (1 -> 32 -> 64 channels, stride 1, no
    padding), each followed by ReLU, then a 2x2 max pool, a flatten, and one
    linear layer from 9216 to 128 features. The result is L2-normalized
    along the feature dimension.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Projection head. 9216 = 64 channels * 12 * 12, which is the size a
        # 28x28 input (e.g. an MNIST digit) has after the two unpadded 3x3
        # convolutions (28 -> 26 -> 24) and the 2x2 pooling (24 -> 12).
        self.fc = nn.Linear(in_features=9216, out_features=128)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: image batch with one channel, laid out as (batch, 1, H, W).
               H and W must make the flattened size equal 9216 (28x28 does).

        Returns:
            Tensor of shape (batch, 128) whose rows have unit L2 norm.
        """
        hidden = self.conv1(x).relu()
        hidden = self.conv2(hidden).relu()
        pooled = nn.functional.max_pool2d(hidden, kernel_size=2)
        # Keep the batch dimension, collapse channels and spatial dims
        flat = pooled.flatten(start_dim=1)
        projected = self.fc(flat)
        return nn.functional.normalize(projected, dim=1)

# Create an instance of the model.
# This is the embedding network whose parameters are later handed to the
# optimizer and which is trained and visualized further down.
model = ContrastiveNet()

# Define the contrastive loss function
class ContrastiveLoss(nn.Module):
    """Cross-entropy over in-batch similarities, each sample being its own positive.

    For every row of ``features`` the logits are the dot products with all
    rows of the batch, divided by ``temperature``. The "correct class" is
    the sample's similarity with itself (the diagonal of the similarity
    matrix); the similarities with every other sample in the batch act as
    negatives.

    NOTE: no labels and no second (augmented) view are used. When the
    features are L2-normalized, the self-similarity is always 1, so the loss
    can only be lowered by reducing the similarity between different samples.

    Args:
        temperature: positive scalar the similarities are divided by before
            the softmax (default 0.5).
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, features):
        """Compute the loss for one batch of embeddings.

        Args:
            features: 2-D tensor of shape (batch, dim), one embedding per row.

        Returns:
            Scalar tensor: mean cross-entropy over the batch. For a batch
            of a single sample there are no negatives and the loss is 0.
        """
        batch_size = features.shape[0]
        device = features.device

        # Pairwise dot products between all embeddings in the batch
        similarity_matrix = torch.matmul(features, features.T)

        # The diagonal holds each sample's similarity with itself
        mask = torch.eye(batch_size, dtype=torch.bool, device=device)

        # Explicit shapes (rather than -1) so that a batch of one sample,
        # which has zero negatives, can still be reshaped.
        positives = similarity_matrix[mask].view(batch_size, 1)
        negatives = similarity_matrix[~mask].view(batch_size, batch_size - 1)

        # Column 0 is the positive, the remaining columns are the negatives
        logits = torch.cat([positives, negatives], dim=1)

        # The target class is therefore always index 0
        labels = torch.zeros(batch_size, dtype=torch.long, device=device)

        return nn.functional.cross_entropy(logits / self.temperature, labels)

# Create instances of the loss function and optimizer.
# The loss is built with its default temperature; Adam updates all of the
# model's parameters with a learning rate of 0.001.
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define the training function
def train(model, train_loader, criterion, optimizer, epochs=5):
    """Run the optimization loop and print the mean batch loss after each epoch.

    Args:
        model: network to train; it is switched to training mode first.
        train_loader: iterable of (data, target) batches that also supports
            len(). The target of each batch is ignored.
        criterion: callable that takes the model output and returns a
            scalar loss tensor.
        optimizer: optimizer holding the model's parameters.
        epochs: number of full passes over train_loader (default 5).

    Returns:
        None. Progress is reported with print().
    """
    model.train()
    for epoch in range(epochs):
        batch_losses = []
        for data, _ in train_loader:
            # Clear gradients left over from the previous step
            optimizer.zero_grad()
            # Forward pass and loss on the embeddings only
            loss = criterion(model(data))
            # Backpropagate and apply the update
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        total_loss = sum(batch_losses)
        # Print the average loss for the epoch
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}')

# Train the model.
# epochs is not passed, so train() runs its default of 5 epochs and prints
# the average loss after each one.
train(model, train_loader, criterion, optimizer)

# Function to visualize the learned embeddings
def visualize_embeddings(model, dataset, num_samples=1000):
    """Scatter-plot the first two embedding dimensions of the first samples.

    The model is put in evaluation mode and applied, without gradient
    tracking, to the first ``num_samples`` items of ``dataset`` one at a
    time. Only embedding dimensions 0 and 1 are drawn; no dimensionality
    reduction is performed on the remaining dimensions.

    Args:
        model: network mapping a batch of images to a (batch, dim) tensor
            with dim >= 2.
        dataset: indexable collection supporting len() whose items are
            (image_tensor, integer_label) pairs.
        num_samples: how many leading items to embed (default 1000). Values
            larger than the dataset are clamped to len(dataset).

    Returns:
        None. The figure is displayed with plt.show().
    """
    model.eval()
    # Never index past the end of the dataset
    sample_count = min(num_samples, len(dataset))
    embeddings = []
    labels = []

    # Generate embeddings for a subset of the dataset
    with torch.no_grad():
        for sample_idx in range(sample_count):
            img, label = dataset[sample_idx]
            # Add a batch dimension for the model, then drop it again
            embedding = model(img.unsqueeze(0))
            # .cpu() keeps this working when the model lives on another device
            embeddings.append(embedding.squeeze(0).cpu())
            labels.append(label)

    # Stack tensors directly; building a tensor from a list of numpy arrays
    # is slow and makes PyTorch emit a warning.
    embeddings = torch.stack(embeddings)
    labels = torch.tensor(labels)

    # Plot the embeddings, one color per digit class 0-9
    plt.figure(figsize=(10, 8))
    for digit in range(10):
        idx = labels == digit
        plt.scatter(embeddings[idx, 0], embeddings[idx, 1], label=str(digit))
    plt.legend()
    plt.title('2D Visualization of Learned Representations')
    plt.show()

# Visualize the embeddings.
# num_samples is not passed, so the first 1000 training images are embedded;
# only the first two embedding dimensions are plotted.
visualize_embeddings(model, train_dataset)

# Explanation of the Contrastive Learning process:
# 1. We define a neural network that learns to map images to a 128-dimensional embedding space.
# 2. The contrastive loss used here never looks at the digit labels: each image's only
#    "positive" is the image itself, and every other image in the batch is treated as a
#    negative, so the loss mainly pushes the embeddings of different images apart.
# 3. During training, we compute the similarity between all pairs of images in a batch.
# 4. The loss function is a cross-entropy over those similarities: it rewards a high
#    self-similarity (the positive) relative to the similarity with every other image
#    in the batch (the negatives), regardless of which digits the images show.
# 5. After training, we visualize the learned embeddings to see how well the model
#    has separated different digits in the embedding space.

# This tutorial provides a basic introduction to contrastive learning.
# In practice, more advanced techniques like data augmentation, larger models,
# and more sophisticated loss functions are often used to achieve better results.

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 28477333.76it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 3906720.86it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 11589280.24it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3018144.61it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Epoch 1/5, Loss: 2.3663
