**Introduction**

Sample reweighting is a common technique in domain adaptation, where the goal is to handle distribution shifts between a source domain (used for training) and a target domain (used for deployment). The idea is to assign weights to source domain samples to make them more representative of the target domain. This way, the model can better adapt to the target domain's characteristics.

**Imports**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler

from torchvision import datasets, transforms


import matplotlib.pyplot as plt



import kagglehub

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Fetch the latest version of the USPS dataset from Kaggle and report where
# the files were placed on disk.
usps_dataset_handle = "bistaumanga/usps-dataset"
path = kagglehub.dataset_download(usps_dataset_handle)

print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/bistaumanga/usps-dataset/versions/1


In [None]:
# Preprocessing: convert each image to a tensor, then normalize the single
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST train/test splits; the files are downloaded into the root directory
# if they are not already there.
mnist_root = './data'
mnist_train = datasets.MNIST(root=mnist_root, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root=mnist_root, train=False, download=True, transform=transform)

100%|██████████| 9.91M/9.91M [00:00<00:00, 53.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.71MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.6MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.04MB/s]


**Data Processing**

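Set up sample weights for reweighting: ideally, each MNIST sample's weight estimates how much more likely that sample is under the USPS (target) distribution than under the MNIST (source) distribution. Such weights can be computed with a domain classifier, kernel density estimation, or another density-ratio estimator, so that MNIST samples resembling USPS are drawn more often. The next cell uses random weights purely as a placeholder.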



In [None]:
# Placeholder weights: one uniform random value per MNIST training sample,
# rescaled so that the weights sum to 1.
num_source_samples = len(mnist_train)
raw_weights = torch.rand(num_source_samples)
mnist_weights = raw_weights / raw_weights.sum()

In [None]:
# Inspect the normalized weights (shown as the cell's output in a notebook).
mnist_weights

tensor([3.5041e-06, 3.2884e-05, 2.3945e-05,  ..., 3.8541e-06, 4.4937e-06,
        2.8864e-05])

In [None]:
# Draw MNIST indices in proportion to their weights, with replacement, using
# as many draws per epoch as there are weights; batches hold 64 samples.
draws_per_epoch = len(mnist_weights)
sampler = WeightedRandomSampler(mnist_weights, draws_per_epoch, replacement=True)
mnist_loader = DataLoader(mnist_train, batch_size=64, sampler=sampler)

**Model**

In [None]:
# Dummy model (binary classification for simplicity)
class SimpleModel(nn.Module):
    """Single linear layer that maps a flattened image to one logit.

    Args:
        in_features: Number of values per sample after flattening. Defaults
            to 28 * 28, the size of a flattened 28x28 image.
    """

    def __init__(self, in_features: int = 28 * 28):
        super().__init__()
        self.fc = nn.Linear(in_features, 1)

    def forward(self, x):
        """Flatten each sample and return logits of shape (batch, 1)."""
        # reshape() copies when the input is non-contiguous (e.g. after a
        # transpose), whereas view() would raise in that case.
        flat = x.reshape(x.size(0), -1)
        return self.fc(flat)


# Build the network on the selected device, together with the binary
# cross-entropy-with-logits loss and an Adam optimizer over its parameters.
model = SimpleModel()
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

**Train Loop**

In [None]:
# Training loop: fit the model on MNIST (source domain) with per-sample loss
# weights, then report accuracy on USPS (target domain) after every epoch.
#
# NOTE(review): `usps_loader` is not defined in the visible part of this
# notebook; it must be created before this cell runs. The model's input layer
# takes 28 * 28 features, so the loader presumably has to yield 28x28 images
# with integer digit labels -- confirm against the USPS loading code.
num_epochs = 10

# The model emits a single logit, so binary cross-entropy needs targets in
# [0, 1], but MNIST labels are the digits 0-9. Feeding raw digits into BCE
# makes the loss unbounded below, and 0/1 predictions cannot be compared with
# digit labels. Placeholder binary task: "is the digit >= 5?" -- replace this
# with the real binary task.
binary_threshold = 5


def _to_binary_targets(digit_labels):
    """Return float targets: 1.0 where the digit is >= binary_threshold, else 0.0."""
    return (digit_labels >= binary_threshold).float()


for epoch in range(num_epochs):
    model.train()

    total_loss = 0.0
    for inputs, labels in mnist_loader:  # Training on MNIST (source domain)
        inputs = inputs.to(device)
        targets = _to_binary_targets(labels).to(device)

        # Example sample weights (random for now, replace with real weights)
        sample_weights = torch.rand(len(targets), device=device)

        # Forward pass. squeeze(1) removes only the logit dimension, so a
        # batch holding a single sample keeps shape (1,) instead of
        # collapsing to a 0-dim tensor that no longer matches the targets.
        logits = model(inputs).squeeze(1)

        # `criterion` averages over the batch and returns a scalar, so
        # multiplying it by the weights would only rescale the batch loss.
        # Per-sample weighting needs the unreduced loss.
        per_sample_loss = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )

        # Apply sample weights
        weighted_loss = (per_sample_loss * sample_weights).mean()

        # Backward pass and optimization
        optimizer.zero_grad()
        weighted_loss.backward()
        optimizer.step()

        total_loss += weighted_loss.item()

    # total_loss is the sum of the per-batch weighted losses for this epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")

    # Evaluate on USPS (target domain)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in usps_loader:
            inputs = inputs.to(device)
            # Same binarization as in training so predictions and targets
            # live in the same label space.
            targets = _to_binary_targets(labels).to(device)

            logits = model(inputs).squeeze(1)
            predictions = (torch.sigmoid(logits) > 0.5).float()
            correct += (predictions == targets).sum().item()
            total += len(targets)

    # Guard against an empty target loader instead of dividing by zero.
    accuracy = correct / total * 100 if total else 0.0
    print(f"Evaluation Accuracy on USPS: {accuracy:.2f}%")
