In [14]:
import numpy as np
import torch
import importlib
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from scipy.stats import wasserstein_distance
from torch.optim.lr_scheduler import StepLR
from geomloss import SamplesLoss

In [3]:
import h5py
with h5py.File('usps.h5', 'r') as hf:
        train_usps = hf.get('train')
        X_train_usps = train_usps.get('data')[:]
        y_train_usps = train_usps.get('target')[:]
        test_usps = hf.get('test')
        X_test_usps = test_usps.get('data')[:]
        y_test_usps = test_usps.get('target')[:]

In [4]:
def read_mnist(X_train_path, y_train_path, X_test_path, y_test_path):
    with open(X_train_path, 'rb') as f:
        # Skip the magic number and dimension info
        f.read(16)
        X_train_mnist = np.fromfile(f, dtype=np.uint8).reshape(-1, 1, 28, 28)

    with open(y_train_path, 'rb') as f:
        # Skip the magic number and dimension info
        f.read(8)
        y_train_mnist = np.fromfile(f, dtype=np.uint8)
        
    with open(X_test_path, 'rb') as f:
        # Skip the magic number and dimension info
        f.read(16)
        X_test_mnist = np.fromfile(f, dtype=np.uint8).reshape(-1, 1, 28, 28)
        
    with open(y_test_path, 'rb') as f:
        # Skip the magic number and dimension info
        f.read(8)
        y_test_mnist = np.fromfile(f, dtype=np.uint8)

    return X_train_mnist, y_train_mnist, X_test_mnist, y_test_mnist

# Load MNIST data
X_train_path = "mnist/mnist/archive/train-images-idx3-ubyte/train-images-idx3-ubyte"
y_train_path = "mnist/mnist/archive/train-labels-idx1-ubyte/train-labels-idx1-ubyte"
X_test_path = "mnist/mnist/archive/t10k-images-idx3-ubyte/t10k-images-idx3-ubyte"
y_test_path = "mnist/mnist/archive/t10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte"
X_train_mnist, y_train_mnist, X_test_mnist, y_test_mnist = read_mnist(X_train_path, y_train_path, X_test_path, y_test_path)

In [5]:
# Convert the datasets to PyTorch tensors
X_train_mnist = torch.tensor(X_train_mnist).float()
y_train_mnist = torch.tensor(y_train_mnist).long()
X_test_mnist = torch.tensor(X_test_mnist).float()
y_test_mnist = torch.tensor(y_test_mnist).long()

# Ensure the data is reshaped correctly (no extra singleton dimension)
X_train_mnist = X_train_mnist.reshape(-1, 1, 28, 28)  # MNIST images are 28x28
X_test_mnist = X_test_mnist.reshape(-1, 1, 28, 28)  # MNIST images are 28x28


# Normalize the datasets
mnist_transform = transforms.Compose([
    transforms.Normalize((0.5,), (0.5,))
])

X_train_mnist = mnist_transform(X_train_mnist)
X_test_mnist = mnist_transform(X_test_mnist)

# Create data loaders
batch_size = 64
train_dataset_mnist = TensorDataset(X_train_mnist, y_train_mnist)
train_loader_mnist = DataLoader(train_dataset_mnist, shuffle=True)

test_dataset_mnist = TensorDataset(X_test_mnist, y_test_mnist)
test_loader_mnist = DataLoader(test_dataset_mnist, shuffle=True)

In [6]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class USPSDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.labels[idx]

        # Convert array to PIL Image
        image = Image.fromarray(image.squeeze(), mode='L')

        # Apply the transform
        if self.transform:
            image = self.transform(image)

        # Convert label to tensor
        label = torch.tensor(label, dtype=torch.long)

        return image, label

usps_transform = transforms.Compose([
    transforms.Resize((28, 28)),  # Resize images to 28x28
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Recreate the USPS dataset and dataloader
train_dataset_usps = USPSDataset(X_train_usps, y_train_usps, transform=usps_transform)
train_loader_usps = DataLoader(train_dataset_usps, shuffle=True)

test_dataset_usps = USPSDataset(X_test_usps, y_train_usps, transform=usps_transform)
test_loader_usps = DataLoader(test_dataset_usps, shuffle=True)

In [8]:
class FeatureExtractorCNN(nn.Module):
    def __init__(self):
        super(FeatureExtractorCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(3 * 3 * 128, 1024)
        self.fc2 = nn.Linear(1024, 10) 
        self.dropout = nn.Dropout(p=0.2) 

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = self.dropout(x) 
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = self.dropout(x) 
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 3 * 3 * 128) 
        x = self.dropout(x)  
        x = F.relu(self.fc1(x))
        x = self.dropout(x)  
        x = self.fc2(x)
        return x


def compute_wasserstein_distance(source_features, target_features):
    source_features_flat = source_features.reshape(source_features.shape[0], -1)
    target_features_flat = target_features.reshape(target_features.shape[0], -1)
    
    wd = 0
    for i in range(source_features_flat.shape[1]):
        wd += wasserstein_distance(source_features_flat[:, i], target_features_flat[:, i])
    return wd / source_features_flat.shape[1]


def train(model, source_loader, target_loader, num_epochs, criterion, weight_wasserstein, loss_function_str=None):
    model.train()
    feature_optimizer = optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = StepLR(feature_optimizer, step_size=10, gamma=0.1)
    
    if loss_function_str == 'wasserstein':
        wasserstein_loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
    
    for epoch in range(num_epochs):
        for (source_data, source_labels), (target_data, _) in zip(source_loader, target_loader):
            feature_optimizer.zero_grad()

            source_features = model(source_data)
            target_features = model(target_data)

            classification_loss = criterion(source_features, source_labels)
            wd_loss = 0

            if loss_function_str == 'wasserstein':
                wd_loss = wasserstein_loss(source_features, target_features)

            total_loss = classification_loss + weight_wasserstein * wd_loss
            total_loss.backward()
            feature_optimizer.step()
        
        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs}, LR: {scheduler.get_last_lr()} completed.")


In [11]:
feature_extractor = FeatureExtractorCNN()

# Hyperparameters
num_epochs = 50
loss_function_str = 'wasserstein'  # This is just a string to control the use of Wasserstein loss
criterion = nn.CrossEntropyLoss()  # Loss function object
weight_wasserstein = 0.25

# Start the training process
train(feature_extractor, train_loader_mnist, train_loader_usps, num_epochs,
      criterion, weight_wasserstein, loss_function_str)


Epoch 1/50, LR: [2e-05] completed.
Epoch 2/50, LR: [2e-05] completed.
Epoch 3/50, LR: [2e-05] completed.
Epoch 4/50, LR: [2e-05] completed.
Epoch 5/50, LR: [2e-05] completed.
Epoch 6/50, LR: [2e-05] completed.
Epoch 7/50, LR: [2e-05] completed.
Epoch 8/50, LR: [2e-05] completed.
Epoch 9/50, LR: [2e-05] completed.
Epoch 10/50, LR: [2.0000000000000003e-06] completed.
Epoch 11/50, LR: [2.0000000000000003e-06] completed.
Epoch 12/50, LR: [2.0000000000000003e-06] completed.
Epoch 13/50, LR: [2.0000000000000003e-06] completed.
Epoch 14/50, LR: [2.0000000000000003e-06] completed.
Epoch 15/50, LR: [2.0000000000000003e-06] completed.
Epoch 16/50, LR: [2.0000000000000003e-06] completed.
Epoch 17/50, LR: [2.0000000000000003e-06] completed.
Epoch 18/50, LR: [2.0000000000000003e-06] completed.
Epoch 19/50, LR: [2.0000000000000003e-06] completed.
Epoch 20/50, LR: [2.0000000000000004e-07] completed.
Epoch 21/50, LR: [2.0000000000000004e-07] completed.
Epoch 22/50, LR: [2.0000000000000004e-07] complet

In [12]:
def evaluate_model(model, test_loader, device='cpu'):
    model.eval()
    model.to(device)
    all_preds = []
    all_labels = []
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            if data.shape[1] != 1:
                raise ValueError(f"Source data should have 1 channel, got {data.shape[1]}")

            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the model on the test images: {accuracy}%')

    return all_labels, all_preds

# Evaluate on MNIST Test Set
mnist_test_labels, mnist_test_preds = evaluate_model(feature_extractor, test_loader_mnist)

# Evaluate on USPS Test Set
usps_test_labels, usps_test_preds = evaluate_model(feature_extractor, test_loader_usps)

Accuracy of the model on the test images: 98.58%
Accuracy of the model on the test images: 7.872446437468859%
