In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T


def load_cifar10_batch(path):
    """Read a CIFAR-10 style binary batch file into image and label arrays.

    The file is treated as a sequence of fixed-size records of 3073 bytes:
    one label byte followed by 3072 pixel bytes, which are reshaped to
    ``(3, 32, 32)``.

    Args:
        path: Path of the binary batch file to read.

    Returns:
        A tuple ``(images, labels)`` of ``uint8`` arrays. ``images`` has shape
        ``(N, 3, 32, 32)`` and ``labels`` has shape ``(N,)``, where ``N`` is
        the number of records in the file. Both are views over the buffer
        read from disk.

    Raises:
        ValueError: If the file size is not a whole number of 3073-byte
            records (for example a truncated download).
    """
    pixel_bytes = 3 * 32 * 32
    record_bytes = 1 + pixel_bytes  # label byte + pixel bytes = 3073

    with open(path, 'rb') as f:
        data = np.frombuffer(f.read(), dtype=np.uint8)

    # Report a damaged file explicitly instead of letting reshape() fail with
    # a message that does not mention the file or the record layout.
    if data.size % record_bytes != 0:
        raise ValueError(
            f"{path!r} holds {data.size} bytes, which is not a multiple of the "
            f"{record_bytes}-byte record size (1 label byte + {pixel_bytes} "
            f"pixel bytes)"
        )

    data = data.reshape(-1, record_bytes)
    labels = data[:, 0]
    images = data[:, 1:].reshape(-1, 3, 32, 32)
    return images, labels


class CIFARBinaryDataset(Dataset):
    """Map-style dataset over in-memory image and label arrays.

    Each item is an ``(image, label)`` pair. The image is taken from
    ``images[idx]``, moved from channel-first to channel-last layout, wrapped
    in a PIL image and, when a transform is given, passed through it.
    """

    def __init__(self, images, labels, transform=None):
        """Store the arrays and the optional transform.

        Args:
            images: Indexable collection of channel-first ``uint8`` images.
            labels: Indexable collection of labels, one per image.
            transform: Optional callable applied to the PIL image.

        Raises:
            ValueError: If ``images`` and ``labels`` differ in length.
        """
        # Catch a mismatch here rather than as an IndexError (or silently
        # ignored labels) in the middle of an epoch.
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``; the label is an ``int``."""
        chw = self.images[idx]                     # (3, 32, 32) uint8
        hwc = np.transpose(chw, (1, 2, 0))         # -> (32, 32, 3)
        img = Image.fromarray(hwc)

        # Compare against None so a transform object that happens to be falsy
        # (e.g. an empty container of transforms) is still honoured.
        if self.transform is not None:
            img = self.transform(img)

        return img, int(self.labels[idx])


In [2]:
# Read the training and test splits from the CIFAR-10 binary folder.
# NOTE(review): "train_all.bin" is presumably a concatenation of the separate
# training batch files — confirm how it was produced.
train_images, train_labels = load_cifar10_batch(
    "dataset/cifar-10-batches-bin/train_all.bin"
)
test_images, test_labels = load_cifar10_batch(
    "dataset/cifar-10-batches-bin/test_batch.bin"
)


# IMPORTANT: the model below is created with `pretrained=False` (no pretrained weights)

In [3]:
import timm
import torch
from torch.utils.data import DataLoader

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the ViT-Tiny classifier first, then move it to the chosen device.
model = timm.create_model(
    "vit_tiny_patch16_224",
    pretrained=False,   # <-- NO PRETRAIN
    num_classes=10,     # CIFAR-10
)
model = model.to(device)


  from .autonotebook import tqdm as notebook_tqdm


---

In [3]:
import torch
import json

# 1. Load model (same architecture used before)
import timm

# Use the GPU only when PyTorch can see one; a hard-coded 'cuda' makes this
# cell fail on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = timm.create_model('vit_tiny_patch16_224', num_classes=10)
model = model.to(device)

# 2. Load JSON
with open('parameters/weights.json', 'r') as f:
    data = json.load(f)

# 3. Convert lists to tensors
state_dict = model.state_dict()

for key in data:
    if key not in state_dict:
        print(f"⚠️ Key {key} not in model state_dict, skipping...")
        continue

    arr = data[key]
    expected = state_dict[key]
    tensor = torch.tensor(arr, dtype=expected.dtype)

    # Reshape if needed (JSON lost shape sometimes)
    if tensor.shape != expected.shape:
        # Name the offending entry instead of surfacing a bare reshape error.
        if tensor.numel() != expected.numel():
            raise RuntimeError(
                f"Cannot load {key}: JSON holds {tensor.numel()} values but "
                f"the model expects shape {tuple(expected.shape)}"
            )
        tensor = tensor.reshape(expected.shape)

    state_dict[key] = tensor

# Entries absent from the JSON keep the values of the freshly created model,
# so report them rather than letting them pass unnoticed.
missing_keys = [name for name in state_dict if name not in data]
if missing_keys:
    print(f"⚠️ {len(missing_keys)} model keys missing from JSON (left as initialised): {missing_keys}")

# 4. Load updated weights
model.load_state_dict(state_dict)

print("✅ Weights Loaded Successfully")


  from .autonotebook import tqdm as notebook_tqdm


✅ Weights Loaded Successfully


---

In [4]:
import torch.nn as nn
# Pick the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [5]:
import torchvision.transforms as T

# Training pipeline: random crop (70-100% of the image area) resized to
# 224x224, random horizontal flip, then tensor conversion and normalisation
# with mean 0.5 / std 0.5 on each of the three channels.
transform_train = T.Compose(
    [
        T.RandomResizedCrop(size=224, scale=(0.7, 1.0)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Evaluation pipeline: deterministic resize to 224x224 followed by the same
# tensor conversion and normalisation as the training pipeline.
transform_test = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)


In [6]:
# Wrap the in-memory arrays in datasets first, then batch them.
train_dataset = CIFARBinaryDataset(train_images, train_labels, transform_train)
test_dataset = CIFARBinaryDataset(test_images, test_labels, transform_test)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


In [7]:
# Loss for the 10-way classification head.
criterion = nn.CrossEntropyLoss()

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-4, weight_decay=0.05)

# Cosine learning-rate schedule, stepped once per epoch by the training loop.
# NOTE(review): T_max=50 while the training cell sets EPOCHS = 10, so only the
# first part of the cosine curve is traversed — confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=50)


In [8]:
def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Sized iterable yielding ``(imgs, labels)`` tensor batches.
        optimizer: Optimizer whose gradients are zeroed before, and stepped
            after, each batch.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Device each batch is moved to. When omitted it is taken from
            the model's first parameter (``"cpu"`` for a parameter-free
            model), so the function does not rely on a notebook-level
            ``device`` variable.

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``len(loader)`` is 0.
    """
    if device is None:
        # Follow the model: batches must live where the weights live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    model.train()
    total_loss = 0

    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() detaches the value so the graph is not kept alive
        total_loss += loss.item()

    return total_loss / len(loader)


In [9]:
def evaluate(model, loader, device=None):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        loader: Iterable yielding ``(imgs, labels)`` tensor batches.
        device: Device each batch is moved to. When omitted it is taken from
            the model's first parameter (``"cpu"`` for a parameter-free
            model), so the function does not rely on a notebook-level
            ``device`` variable.

    Returns:
        ``correct / total`` as a float, where a prediction is the argmax of
        the model output along dimension 1.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    if device is None:
        # Follow the model: batches must live where the weights live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    model.eval()
    correct = 0
    total = 0

    # Gradients are not needed for scoring.
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total


In [10]:
EPOCHS = 10   # ViT needs longer training from scratch

for epoch in range(EPOCHS):
    # One pass over the training data, then accuracy on the test loader
    # using the weights as they stand after that pass.
    loss = train_one_epoch(
        model=model,
        loader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
    )
    acc = evaluate(model=model, loader=test_loader)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {loss:.4f} | Test Acc: {acc*100:.2f}%")


Epoch 1/10 | Loss: 0.9265 | Test Acc: 66.50%
Epoch 2/10 | Loss: 0.8921 | Test Acc: 66.90%


KeyboardInterrupt: 

In [11]:
import json
import os

# Snapshot of every entry in the model's state dict, keyed by name.
state = model.state_dict()

# convert tensor → python list (moved to CPU first so GPU tensors work too)
export_dict = {
    name: param.detach().cpu().numpy().tolist()
    for name, param in state.items()
}

weights_path = "parameters/weights.json"

# Create the output folder if it is missing so the export does not fail with
# FileNotFoundError after a training run on a fresh checkout.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)

with open(weights_path, "w") as f:
    json.dump(export_dict, f, indent=2)
