<a href="https://colab.research.google.com/github/manuaishika/mu/blob/data/cmu_stn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

# Data loading with error handling
# Preprocessing: convert images to tensors, then standardize the single
# channel with mean 0.1307 and std 0.3081 (the usual MNIST statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Only the dataset construction touches the network/disk (download=True), so
# that is the only part guarded by the try block.
try:
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
except Exception as e:
    print(f"Error loading data: {e}")
    # `exit()` is an interactive helper injected by the `site` module and is
    # not guaranteed to exist; raising SystemExit is the portable way to stop
    # with status 1, and chaining keeps the original cause for debugging.
    raise SystemExit(1) from e

# Mini-batches of 64; only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Network definition
class Net(nn.Module):
    """Digit classifier preceded by a spatial transformer network (STN).

    The STN predicts one 2x3 affine matrix per input image and resamples the
    image with it; the resampled image is then classified by a small CNN.
    The layer sizes are fixed for single-channel 28x28 inputs: the
    localization features flatten to 10 * 3 * 3 values and the classifier
    features flatten to 320 values.
    """

    def __init__(self):
        super().__init__()
        # Classification branch.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Localization branch: convolutional features used to regress the
        # affine parameters.
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        # Regressor mapping the flattened localization features to the six
        # entries of a 2x3 affine matrix.
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Start from the identity transform: zero weights and a bias of
        # [[1, 0, 0], [0, 1, 0]] make the STN a no-op before training.
        # Done under no_grad instead of through the discouraged `.data`.
        with torch.no_grad():
            self.fc_loc[2].weight.zero_()
            self.fc_loc[2].bias.copy_(
                torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
            )

    def stn(self, x):
        """Resample ``x`` with the affine transform predicted from ``x``.

        Args:
            x: Image batch; the layer sizes assume shape (N, 1, 28, 28).

        Returns:
            Tensor with the same size as ``x`` holding the warped images.
        """
        xs = self.localization(x)
        # flatten(1) keeps the batch dimension fixed, so an unexpected
        # spatial size fails loudly in fc_loc instead of being silently
        # folded into the batch dimension as view(-1, ...) would do.
        xs = xs.flatten(1)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        # align_corners=False is what the functions fall back to when the
        # argument is omitted; passing it explicitly keeps the numerics and
        # avoids the UserWarning emitted on every call otherwise.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)
        return x

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10) for ``x``."""
        # Spatially normalize the input first.
        x = self.stn(x)

        # Two conv -> max-pool -> ReLU stages; dropout on the second conv.
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told the mode explicitly.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# Device setup with fallback
# Use the GPU when CUDA reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Training with progress bar and validation
# Build the network on the chosen device and optimize it with plain SGD.
model = Net()
model = model.to(device)
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

def train(epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        epoch: Epoch number, used only for the progress-bar label and the
            periodic log lines.

    Returns:
        The mean of the per-batch losses over the epoch.
    """
    model.train()
    total_loss = 0
    with tqdm(train_loader, desc=f"Epoch {epoch}", unit="batch") as t:
        for batch_idx, (data, target) in enumerate(t):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it, rather than calling
            # .item() separately for the total, the bar and the log line.
            batch_loss = loss.item()
            total_loss += batch_loss
            t.set_postfix(loss=batch_loss)

            if batch_idx % 100 == 0 and batch_idx > 0:
                # tqdm.write prints without overlapping the active bar; a
                # bare print() leaves a stale bar fragment behind each time.
                tqdm.write(
                    f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                    f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {batch_loss:.6f}'
                )
    return total_loss / len(train_loader)

def validate():
    """Evaluate ``model`` on ``test_loader`` and return the mean loss per sample.

    The loss is summed over individual samples and divided by the number of
    samples seen, so a smaller final batch does not get extra weight (which
    averaging per-batch means over the batch count would give it).

    Returns:
        The average negative log-likelihood per sample.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # reduction='sum' yields the batch total rather than its mean.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            num_samples += target.size(0)
    return total_loss / num_samples

# Ten epochs, numbered from 1: one training pass followed by one validation
# pass, with both losses reported at the end of each epoch.
for epoch in range(1, 10 + 1):
    train_loss, val_loss = train(epoch), validate()
    summary = f'Epoch {epoch} completed - Train Loss: {train_loss:.6f}, Validation Loss: {val_loss:.6f}'
    print(summary)

# Visualization with saved output
def visualize_stn():
    """Save a side-by-side plot of a test image before and after the STN.

    Takes the first batch from ``test_loader``, passes it through
    ``model.stn`` and writes the first image of the batch, original next to
    transformed, to 'stn_visualization.png'.
    """
    model.eval()
    with torch.no_grad():
        batch, _ = next(iter(test_loader))
        batch = batch.to(device)
        warped = model.stn(batch)

    # Left panel: raw input; right panel: STN output.
    panels = (
        ('Original', batch.cpu().numpy()),
        ('Transformed', warped.cpu().numpy()),
    )
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, (title, images) in zip(axes, panels):
        # Only the first sample's single channel is shown.
        ax.imshow(images[0, 0], cmap='gray')
        ax.set_title(title)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
    plt.savefig('stn_visualization.png')
    plt.close()


visualize_stn()
print("Visualization saved as 'stn_visualization.png'")

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.47MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 161kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.51MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.11MB/s]


Using device: cpu


Epoch 1:  11%|█         | 102/938 [00:05<00:55, 15.12batch/s, loss=2.18]



Epoch 1:  22%|██▏       | 204/938 [00:12<00:39, 18.60batch/s, loss=1.76]



Epoch 1:  32%|███▏      | 304/938 [00:17<00:33, 18.71batch/s, loss=1.49]



Epoch 1:  43%|████▎     | 404/938 [00:23<00:28, 18.82batch/s, loss=1.17]



Epoch 1:  54%|█████▎    | 504/938 [00:29<00:22, 18.94batch/s, loss=0.792]



Epoch 1:  64%|██████▍   | 604/938 [00:35<00:17, 18.92batch/s, loss=0.735]



Epoch 1:  75%|███████▌  | 704/938 [00:40<00:12, 19.10batch/s, loss=0.649]



Epoch 1:  86%|████████▌ | 803/938 [00:46<00:08, 16.49batch/s, loss=0.458]



Epoch 1:  96%|█████████▋| 903/938 [00:52<00:01, 19.01batch/s, loss=0.508]



Epoch 1: 100%|██████████| 938/938 [00:53<00:00, 17.41batch/s, loss=0.5]


Epoch 1 completed - Train Loss: 1.178340, Validation Loss: 0.204800


Epoch 2:  11%|█         | 104/938 [00:05<00:45, 18.42batch/s, loss=0.426]



Epoch 2:  22%|██▏       | 203/938 [00:11<00:55, 13.24batch/s, loss=0.305]



Epoch 2:  32%|███▏      | 303/938 [00:16<00:32, 19.42batch/s, loss=0.434]



Epoch 2:  43%|████▎     | 404/938 [00:22<00:29, 18.21batch/s, loss=0.523]



Epoch 2:  54%|█████▎    | 504/938 [00:28<00:23, 18.26batch/s, loss=0.451]



Epoch 2:  64%|██████▍   | 604/938 [00:34<00:19, 17.49batch/s, loss=0.363]



Epoch 2:  75%|███████▍  | 703/938 [00:40<00:12, 18.44batch/s, loss=0.344]



Epoch 2:  86%|████████▌ | 803/938 [00:45<00:07, 18.72batch/s, loss=0.286]



Epoch 2:  96%|█████████▋| 904/938 [00:51<00:01, 18.68batch/s, loss=0.299]



Epoch 2: 100%|██████████| 938/938 [00:53<00:00, 17.59batch/s, loss=0.277]


Epoch 2 completed - Train Loss: 0.397518, Validation Loss: 0.118346


Epoch 3:  11%|█         | 104/938 [00:06<00:48, 17.09batch/s, loss=0.225]



Epoch 3:  22%|██▏       | 203/938 [00:12<00:39, 18.46batch/s, loss=0.389]



Epoch 3:  32%|███▏      | 303/938 [00:18<00:41, 15.45batch/s, loss=0.185]



Epoch 3:  43%|████▎     | 404/938 [00:23<00:27, 19.35batch/s, loss=0.415]



Epoch 3:  54%|█████▎    | 503/938 [00:29<00:33, 12.95batch/s, loss=0.307]



Epoch 3:  64%|██████▍   | 603/938 [00:35<00:17, 18.66batch/s, loss=0.481]



Epoch 3:  75%|███████▍  | 703/938 [00:40<00:12, 19.20batch/s, loss=0.309] 



Epoch 3:  86%|████████▌ | 804/938 [00:46<00:07, 19.10batch/s, loss=0.303]



Epoch 3:  96%|█████████▋| 903/938 [00:52<00:01, 19.17batch/s, loss=0.177]



Epoch 3: 100%|██████████| 938/938 [00:54<00:00, 17.34batch/s, loss=0.151]


Epoch 3 completed - Train Loss: 0.286038, Validation Loss: 0.090195


Epoch 4:  11%|█         | 104/938 [00:05<00:44, 18.68batch/s, loss=0.11]



Epoch 4:  22%|██▏       | 203/938 [00:11<00:38, 19.17batch/s, loss=0.135]



Epoch 4:  32%|███▏      | 303/938 [00:16<00:33, 19.14batch/s, loss=0.16] 



Epoch 4:  43%|████▎     | 404/938 [00:23<00:35, 14.89batch/s, loss=0.277]



Epoch 4:  54%|█████▎    | 504/938 [00:28<00:22, 19.21batch/s, loss=0.223]



Epoch 4:  64%|██████▍   | 603/938 [00:33<00:25, 13.10batch/s, loss=0.136]



Epoch 4:  75%|███████▌  | 704/938 [00:39<00:12, 19.07batch/s, loss=0.256]



Epoch 4:  86%|████████▌ | 804/938 [00:44<00:06, 19.22batch/s, loss=0.241]



Epoch 4:  96%|█████████▋| 903/938 [00:51<00:01, 18.86batch/s, loss=0.33] 



Epoch 4: 100%|██████████| 938/938 [00:52<00:00, 17.72batch/s, loss=0.216]


Epoch 4 completed - Train Loss: 0.234601, Validation Loss: 0.079374


Epoch 5:  11%|█         | 103/938 [00:06<00:44, 18.88batch/s, loss=0.111]



Epoch 5:  22%|██▏       | 204/938 [00:11<00:38, 18.91batch/s, loss=0.182]



Epoch 5:  32%|███▏      | 304/938 [00:17<00:33, 18.98batch/s, loss=0.373]



Epoch 5:  43%|████▎     | 403/938 [00:22<00:28, 18.66batch/s, loss=0.114]



Epoch 5:  54%|█████▎    | 503/938 [00:28<00:35, 12.27batch/s, loss=0.249]



Epoch 5:  64%|██████▍   | 604/938 [00:34<00:18, 18.44batch/s, loss=0.145]



Epoch 5:  75%|███████▍  | 702/938 [00:39<00:17, 13.72batch/s, loss=0.244] 



Epoch 5:  86%|████████▌ | 804/938 [00:45<00:06, 19.46batch/s, loss=0.273]



Epoch 5:  96%|█████████▋| 904/938 [00:51<00:01, 18.50batch/s, loss=0.198]



Epoch 5: 100%|██████████| 938/938 [00:53<00:00, 17.57batch/s, loss=0.37]


Epoch 5 completed - Train Loss: 0.214977, Validation Loss: 0.086309


Epoch 6:  11%|█         | 104/938 [00:05<00:43, 19.03batch/s, loss=0.142]



Epoch 6:  22%|██▏       | 204/938 [00:11<00:38, 19.12batch/s, loss=0.157]



Epoch 6:  32%|███▏      | 303/938 [00:16<00:34, 18.29batch/s, loss=0.092]



Epoch 6:  43%|████▎     | 403/938 [00:23<00:28, 18.81batch/s, loss=0.155]



Epoch 6:  54%|█████▎    | 504/938 [00:28<00:22, 19.25batch/s, loss=0.0945]



Epoch 6:  64%|██████▍   | 602/938 [00:34<00:27, 12.25batch/s, loss=0.127]



Epoch 6:  75%|███████▍  | 703/938 [00:39<00:12, 18.78batch/s, loss=0.133]



Epoch 6:  86%|████████▌ | 803/938 [00:45<00:10, 13.46batch/s, loss=0.109]



Epoch 6:  96%|█████████▋| 903/938 [00:51<00:01, 18.50batch/s, loss=0.0687]



Epoch 6: 100%|██████████| 938/938 [00:53<00:00, 17.65batch/s, loss=0.179]


Epoch 6 completed - Train Loss: 0.183630, Validation Loss: 0.069201


Epoch 7:  11%|█         | 104/938 [00:06<00:44, 18.72batch/s, loss=0.0518]



Epoch 7:  22%|██▏       | 203/938 [00:11<00:38, 18.87batch/s, loss=0.214]



Epoch 7:  32%|███▏      | 303/938 [00:17<00:34, 18.62batch/s, loss=0.206] 



Epoch 7:  43%|████▎     | 403/938 [00:23<00:29, 18.20batch/s, loss=0.178]



Epoch 7:  54%|█████▎    | 503/938 [00:29<00:23, 18.35batch/s, loss=0.0621]



Epoch 7:  64%|██████▍   | 603/938 [00:34<00:18, 18.47batch/s, loss=0.381]



Epoch 7:  75%|███████▍  | 703/938 [00:41<00:14, 16.42batch/s, loss=0.169]



Epoch 7:  86%|████████▌ | 803/938 [00:46<00:07, 19.01batch/s, loss=0.148]



Epoch 7:  96%|█████████▋| 903/938 [00:52<00:02, 12.59batch/s, loss=0.119]



Epoch 7: 100%|██████████| 938/938 [00:54<00:00, 17.13batch/s, loss=0.0666]


Epoch 7 completed - Train Loss: 0.170091, Validation Loss: 0.064216


Epoch 8:  11%|█         | 102/938 [00:05<01:04, 13.05batch/s, loss=0.0915]



Epoch 8:  22%|██▏       | 203/938 [00:11<00:38, 18.90batch/s, loss=0.0876]



Epoch 8:  32%|███▏      | 303/938 [00:17<00:33, 19.06batch/s, loss=0.209]



Epoch 8:  43%|████▎     | 403/938 [00:23<00:28, 18.78batch/s, loss=0.175]



Epoch 8:  54%|█████▎    | 503/938 [00:28<00:23, 18.89batch/s, loss=0.248]



Epoch 8:  64%|██████▍   | 603/938 [00:34<00:17, 18.89batch/s, loss=0.156]



Epoch 8:  75%|███████▍  | 703/938 [00:40<00:12, 18.45batch/s, loss=0.163]



Epoch 8:  86%|████████▌ | 803/938 [00:46<00:07, 18.54batch/s, loss=0.112]



Epoch 8:  96%|█████████▋| 903/938 [00:52<00:02, 16.24batch/s, loss=0.158]



Epoch 8: 100%|██████████| 938/938 [00:54<00:00, 17.12batch/s, loss=0.163]


Epoch 8 completed - Train Loss: 0.152014, Validation Loss: 0.065385


Epoch 9:  11%|█         | 104/938 [00:05<00:45, 18.51batch/s, loss=0.233]



Epoch 9:  22%|██▏       | 204/938 [00:11<00:41, 17.89batch/s, loss=0.135]



Epoch 9:  32%|███▏      | 304/938 [00:17<00:33, 18.87batch/s, loss=0.328]



Epoch 9:  43%|████▎     | 402/938 [00:23<00:43, 12.24batch/s, loss=0.103]



Epoch 9:  54%|█████▎    | 503/938 [00:28<00:22, 19.01batch/s, loss=0.163]



Epoch 9:  64%|██████▍   | 603/938 [00:34<00:24, 13.45batch/s, loss=0.0362]



Epoch 9:  75%|███████▍  | 703/938 [00:40<00:12, 19.16batch/s, loss=0.137]



Epoch 9:  86%|████████▌ | 803/938 [00:45<00:07, 19.09batch/s, loss=0.156]



Epoch 9:  96%|█████████▋| 903/938 [00:51<00:01, 18.78batch/s, loss=0.149]



Epoch 9: 100%|██████████| 938/938 [00:53<00:00, 17.47batch/s, loss=0.196]


Epoch 9 completed - Train Loss: 0.146734, Validation Loss: 0.053390


Epoch 10:  11%|█         | 104/938 [00:06<00:44, 18.59batch/s, loss=0.0834]



Epoch 10:  22%|██▏       | 204/938 [00:11<00:38, 19.07batch/s, loss=0.0773]



Epoch 10:  32%|███▏      | 304/938 [00:17<00:33, 18.79batch/s, loss=0.282]



Epoch 10:  43%|████▎     | 403/938 [00:23<00:29, 18.14batch/s, loss=0.109]



Epoch 10:  54%|█████▎    | 503/938 [00:29<00:28, 15.45batch/s, loss=1.43]



Epoch 10:  64%|██████▍   | 604/938 [00:34<00:17, 18.80batch/s, loss=0.076]



Epoch 10:  75%|███████▍  | 702/938 [00:40<00:18, 12.98batch/s, loss=0.0639]



Epoch 10:  86%|████████▌ | 804/938 [00:46<00:07, 18.31batch/s, loss=0.0998]



Epoch 10:  96%|█████████▋| 904/938 [00:51<00:01, 18.83batch/s, loss=0.161]



Epoch 10: 100%|██████████| 938/938 [00:54<00:00, 17.18batch/s, loss=0.823]


Epoch 10 completed - Train Loss: 0.139035, Validation Loss: 0.086410
Visualization saved as 'stn_visualization.png'
