In [1]:
import torch, torchvision, random, numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import EMNIST
from torchvision import transforms, models
from tqdm.notebook import tqdm
from torchinfo import summary

In [2]:
SEED        = 42
BATCH_SIZE  = 256
EPOCHS      = 25
BASE_LR     = 1e-3
WEIGHT_DECAY= 5e-4
NUM_WORKERS = 2
AMP         = torch.cuda.is_available()

def set_seed(seed):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
set_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device, "AMP:", AMP)

Device: cuda AMP: True


In [3]:
mean = (0.1736,)
std  = (0.3317,)
train_tf = transforms.Compose([
    transforms.RandomCrop(28, padding=2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

data_root = "./data"
train_set = EMNIST(root=data_root, split="letters", train=True, download=True, transform=train_tf)
test_set  = EMNIST(root=data_root, split="letters", train=False, download=True, transform=test_tf)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print(f"Train: {len(train_set)}  Test: {len(test_set)}")

100%|██████████| 562M/562M [00:03<00:00, 169MB/s]


Train: 124800  Test: 20800


In [4]:
class SpinalClassifier(nn.Module):
    def __init__(self, in_features: int, hidden: int = 256, segments: int = 4, num_classes: int = 26,
                 dropout: float = 0.30):
        super().__init__()
        seg = in_features // segments
        self.seg = seg
        self.branches = nn.ModuleList()

        for i in range(segments):
            # input dim for this segment
            inp_dim = seg if i == 0 else seg + hidden
            self.branches.append(
                nn.Sequential(
                    nn.Linear(inp_dim, hidden, bias=False),
                    nn.BatchNorm1d(hidden),
                    nn.ReLU(inplace=True),
                    nn.Dropout(dropout),
                    nn.Linear(hidden, hidden),
                    nn.ReLU(inplace=True),
                )
            )

        self.fc_out = nn.Linear(hidden * segments, num_classes)

    def forward(self, x):
        x = x.flatten(1)              
        h_list = []
        for i, branch in enumerate(self.branches):
            start = i * self.seg
            end   = start + self.seg
            seg_x = x[:, start:end]
            if i != 0:
                seg_x = torch.cat([seg_x, h_list[-1]], dim=1)
            h_i = branch(seg_x)
            h_list.append(h_i)

        h_concat = torch.cat(h_list, dim=1)  # (B, hidden * segments)
        return self.fc_out(h_concat)


class SpinalResNet18(nn.Module):
    def __init__(self, num_classes: int = 26):
        super().__init__()
        base = models.resnet18(weights=None)
        base.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        base.maxpool = nn.Identity()                       # keep resolution 28×28 → 7×7 at the end
        self.backbone = nn.Sequential(*list(base.children())[:-1])  # drop the original FC

        with torch.no_grad():
            dummy = torch.zeros(1, 1, 28, 28)
            feat_dim = self.backbone(dummy).flatten(1).shape[1]     # should be 512

        self.classifier = SpinalClassifier(feat_dim, hidden=256, segments=4, num_classes=num_classes)

    def forward(self, x):
        feats = self.backbone(x)      # (B, 512, 1, 1)
        return self.classifier(feats)
model = SpinalResNet18().to(device)
print("Params (Spinal ResNet‑18):", sum(p.numel() for p in model.parameters())/1e6, "M")

Params (Spinal ResNet‑18): 11.787226 M


In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.cuda.amp.GradScaler(enabled=AMP)

  scaler = torch.cuda.amp.GradScaler(enabled=AMP)


In [6]:
@torch.no_grad()
def evaluate():
    model.eval(); corr=tot=0; loss_sum=0
    loop = tqdm(test_loader, leave=False)
    for imgs,targets in loop:
        imgs,targets = imgs.to(device), (targets-1).to(device)
        with torch.cuda.amp.autocast(enabled=AMP):
            out = model(imgs)
            loss = criterion(out,targets)
        loss_sum += loss.item()*imgs.size(0)
        corr += (out.argmax(1)==targets).sum().item(); tot += targets.size(0)
        loop.set_description("Eval"); loop.set_postfix(acc=100*corr/tot)
    return loss_sum/tot, corr/tot

def train_one_epoch(epoch):
    model.train(); run_loss=corr=tot=0
    loop = tqdm(train_loader, leave=False)
    for imgs,targets in loop:
        imgs,targets = imgs.to(device), (targets-1).to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=AMP):
            outputs = model(imgs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        run_loss += loss.item()*imgs.size(0)
        corr += (outputs.argmax(1)==targets).sum().item(); tot += targets.size(0)
        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=run_loss/tot, acc=100*corr/tot)

best_acc=0
for ep in range(1,EPOCHS+1):
    train_one_epoch(ep)
    val_loss,val_acc = evaluate()
    scheduler.step()
    print(f"Ep {ep}: val_acc={val_acc*100:.2f}%")
    if val_acc > best_acc:
        best_acc = val_acc; torch.save(model.state_dict(), "best_spinal_resnet_emnist.pth")
        print("✓ Saved best model")

print("Best accuracy:", best_acc*100)


  0%|          | 0/488 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=AMP):


  0%|          | 0/82 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=AMP):


Ep 1: val_acc=90.27%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 2: val_acc=93.22%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 3: val_acc=93.59%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 4: val_acc=94.55%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 5: val_acc=94.05%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 6: val_acc=94.27%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 7: val_acc=94.37%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 8: val_acc=94.89%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 9: val_acc=94.93%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 10: val_acc=94.87%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 11: val_acc=95.31%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 12: val_acc=95.26%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 13: val_acc=95.23%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 14: val_acc=95.28%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 15: val_acc=95.32%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 16: val_acc=95.57%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 17: val_acc=95.52%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 18: val_acc=95.65%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 19: val_acc=95.81%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 20: val_acc=95.78%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 21: val_acc=95.92%
✓ Saved best model


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 22: val_acc=95.83%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 23: val_acc=95.91%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 24: val_acc=95.88%


  0%|          | 0/488 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

Ep 25: val_acc=95.91%
Best accuracy: 95.91826923076923
