# CIFAR-10 — Colab Quickstart (EAUT)
Chọn **Runtime → Change runtime type → GPU** rồi chạy lần lượt các cell từ trên xuống.


In [None]:
# Install dependencies used by the generated scripts: scikit-learn (metrics), matplotlib
# (plots), pandas (learning-curve cell) and thop (optional FLOPs estimate; the scripts
# fall back to None when it is missing).
# NOTE(review): `-U` upgrades packages the Colab image already ships; if pip reports
# dependency conflicts, installing only `thop` is likely enough — confirm on the current image.
!pip -q install -U scikit-learn matplotlib pandas thop


In [None]:
# Directory that receives the generated project files (scripts + model definition).
import json
import os
import textwrap

root = "/content/cifar10_pack"
os.makedirs(root, exist_ok=True)  # safe to re-run: an existing directory is kept
# Generate train_cifar10.py inside `root`, the directory the training cell `%cd`s into.
# A raw triple-quoted string keeps the script readable and its "\n" escapes intact.
with open(os.path.join(root, "train_cifar10.py"), "w", encoding="utf-8") as f:
    f.write(r'''# train_cifar10.py
"""Train a CIFAR-10 classifier.

Writes into --outdir: hparams.json, train_log.csv (one row per epoch), best.pt
(checkpoint with the best validation accuracy) and result_summary.csv (one row per run).
"""
import os, csv, json, time, random, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from sklearn.metrics import f1_score
from models_simple import SimpleCNN

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
MODEL_CHOICES = ["simplecnn", "resnet18", "mobilenetv2", "efficientnet_b0", "vgg11_bn"]
SUMMARY_FIELDS = ["model", "dataset", "epochs", "batch", "opt", "lr0", "wd", "sched", "aug", "seed",
                  "val_best_acc", "test_acc", "test_f1", "params_m", "flops_g",
                  "latency_cpu_ms", "latency_gpu_ms", "outdir"]
LABEL_SMOOTHING = 0.1  # training loss only; evaluate() reports plain cross-entropy


def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU + all GPUs).

    cuDNN autotuning is left on for speed, so GPU runs are not bit-for-bit reproducible.
    """
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = False
    cudnn.benchmark = True


def cifar10_loaders(data_root, batch_size=128, val_ratio=0.1, seed=42, aug=True, num_workers=2):
    """Build (train, val, test) DataLoaders.

    The validation split is a seeded random subset of the official training set. It is
    read through the test-time transform (no random crop/flip) so validation metrics and
    checkpoint selection are not computed on augmented images.

    Raises:
        ValueError: if val_ratio leaves an empty train or validation split.
    """
    normalize = [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
    train_aug = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] if aug else []
    train_tf = transforms.Compose(train_aug + normalize)
    test_tf = transforms.Compose(normalize)

    train_full = datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_tf)
    val_full = datasets.CIFAR10(root=data_root, train=True, download=False, transform=test_tf)
    test_set = datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_tf)

    n = len(train_full)
    n_val = int(n * val_ratio)
    n_train = n - n_val
    if n_val == 0 or n_train == 0:
        raise ValueError(f"val_ratio={val_ratio} leaves an empty train or validation split")
    # Split indices (not a dataset) so the augmented and the clean view share one partition.
    g = torch.Generator().manual_seed(seed)
    train_idx, val_idx = random_split(range(n), [n_train, n_val], generator=g)
    train_set = Subset(train_full, train_idx.indices)
    val_set = Subset(val_full, val_idx.indices)

    def make_loader(ds, shuffle):
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                          pin_memory=torch.cuda.is_available())

    return make_loader(train_set, True), make_loader(val_set, False), make_loader(test_set, False)


def build_model(name: str, num_classes=10):
    """Create an untrained model whose classifier head outputs `num_classes` logits."""
    name = name.lower()
    if name == "simplecnn":
        return SimpleCNN(num_classes=num_classes)
    if name == "resnet18":
        m = models.resnet18(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    if name == "mobilenetv2":
        m = models.mobilenet_v2(weights=None)
        m.classifier[1] = nn.Linear(m.last_channel, num_classes)
        return m
    if name == "efficientnet_b0":
        m = models.efficientnet_b0(weights=None)
        m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
        return m
    if name == "vgg11_bn":
        m = models.vgg11_bn(weights=None)
        m.classifier[6] = nn.Linear(m.classifier[6].in_features, num_classes)
        return m
    raise ValueError(f"Unknown model: {name}")


def count_params_m(model):
    """Number of trainable parameters, in millions."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6


def get_optimizer(model, name, lr, wd, momentum=0.9):
    """SGD with Nesterov momentum for name == "sgd", otherwise AdamW."""
    if name.lower() == "sgd":
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd, nesterov=True)
    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)


def get_scheduler(optimizer, name, epochs, step_size=30, gamma=0.1):
    """Per-epoch LR scheduler: cosine to 0 over `epochs`, step decay, or None for any other name."""
    name = name.lower()
    if name == "cosine":
        return CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.0)
    if name == "step":
        return StepLR(optimizer, step_size=step_size, gamma=gamma)
    return None


def train_one_epoch(model, loader, device, optimizer, scaler=None):
    """Train for one epoch; return (mean label-smoothed loss, accuracy).

    When `scaler` is given, the forward pass runs under autocast and gradients are scaled.
    """
    model.train()
    total = correct = 0
    loss_sum = 0.0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            with torch.autocast(device_type=device.type):
                logits = model(x)
                loss = F.cross_entropy(logits, y, label_smoothing=LABEL_SMOOTHING)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x)
            loss = F.cross_entropy(logits, y, label_smoothing=LABEL_SMOOTHING)
            loss.backward()
            optimizer.step()
        loss_sum += loss.item() * y.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += y.size(0)
    return loss_sum / total, correct / total


@torch.no_grad()
def evaluate(model, loader, device):
    """Return (mean cross-entropy, accuracy, macro-F1) over `loader`."""
    model.eval()
    total = correct = 0
    loss_sum = 0.0
    ys, ps = [], []
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        loss_sum += F.cross_entropy(logits, y, reduction="sum").item()
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
        ys.append(y.cpu())
        ps.append(pred.cpu())
    y_true = torch.cat(ys).numpy()
    y_pred = torch.cat(ps).numpy()
    return loss_sum / total, correct / total, f1_score(y_true, y_pred, average="macro")


@torch.no_grad()
def measure_latency(model, device, input_size=(1, 3, 32, 32), warmup=30, iters=100):
    """Mean inference latency in milliseconds for one random input.

    CUDA is synchronized before each clock read so queued kernels are included.
    """
    x = torch.randn(*input_size, device=device)
    for _ in range(warmup):
        model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0 / iters


def maybe_flops(model, device):
    """Best-effort compute estimate for one 3x32x32 input via thop, or None if unavailable.

    thop.profile counts multiply-accumulates, so the value is in GMACs.
    """
    try:
        from thop import profile
        x = torch.randn(1, 3, 32, 32, device=device)
        flops, _ = profile(model, inputs=(x,), verbose=False)
        return flops / 1e9
    except Exception:  # deliberate: thop is optional and must never abort a finished run
        return None


def main():
    ap = argparse.ArgumentParser(description="Train a CIFAR-10 classifier.")
    ap.add_argument("--data_root", type=str, default="./data")
    ap.add_argument("--outdir", type=str, default="runs/exp_cifar10")
    ap.add_argument("--model", type=str, default="resnet18", choices=MODEL_CHOICES)
    ap.add_argument("--epochs", type=int, default=20)
    ap.add_argument("--batch", type=int, default=128)
    ap.add_argument("--opt", type=str, default="sgd", choices=["sgd", "adamw"])
    ap.add_argument("--lr", type=float, default=0.1)
    ap.add_argument("--wd", type=float, default=5e-4)
    ap.add_argument("--sched", type=str, default="cosine", choices=["cosine", "step", "none"])
    ap.add_argument("--step_size", type=int, default=30)
    ap.add_argument("--gamma", type=float, default=0.1)
    ap.add_argument("--val_ratio", type=float, default=0.1)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--no_aug", action="store_true")
    ap.add_argument("--amp", action="store_true")
    args = ap.parse_args()
    if args.epochs < 1:
        ap.error("--epochs must be >= 1")

    set_seed(args.seed)
    os.makedirs(args.outdir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader, test_loader = cifar10_loaders(
        args.data_root, args.batch, args.val_ratio, args.seed, not args.no_aug, args.num_workers)
    model = build_model(args.model, 10).to(device)
    opt = get_optimizer(model, args.opt, args.lr, args.wd)
    sched = None if args.sched == "none" else get_scheduler(opt, args.sched, args.epochs, args.step_size, args.gamma)
    # Mixed precision only on CUDA; the scaler protects fp16 gradients from underflow.
    scaler = torch.cuda.amp.GradScaler() if (args.amp and device.type == "cuda") else None

    hparams = vars(args).copy()
    hparams["device"] = str(device)
    hparams["params_m"] = round(count_params_m(model), 4)
    with open(os.path.join(args.outdir, "hparams.json"), "w") as f:
        json.dump(hparams, f, indent=2)

    log_path = os.path.join(args.outdir, "train_log.csv")
    with open(log_path, "w") as f:
        f.write("epoch,train_loss,train_acc,val_loss,val_acc,val_f1,lr\n")

    ckpt_path = os.path.join(args.outdir, "best.pt")
    best_acc = -1.0
    for ep in range(1, args.epochs + 1):
        tl, ta = train_one_epoch(model, train_loader, device, opt, scaler)
        vl, va, vf1 = evaluate(model, val_loader, device)
        lr_now = opt.param_groups[0]["lr"]  # LR used during this epoch (read before stepping)
        if sched is not None:
            sched.step()
        with open(log_path, "a") as f:
            f.write(f"{ep},{tl:.6f},{ta:.6f},{vl:.6f},{va:.6f},{vf1:.6f},{lr_now:.6f}\n")
        if va > best_acc:
            best_acc = va
            torch.save({"model": model.state_dict(), "epoch": ep, "val_acc": va}, ckpt_path)
        print(f"[{ep:03d}/{args.epochs}] train_acc={ta:.4f} val_acc={va:.4f} val_f1={vf1:.4f}")

    # Report test metrics for the best-validation checkpoint, not the last epoch.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model"])
    _, test_acc, test_f1 = evaluate(model, test_loader, device)

    model.eval()
    lat_cpu = measure_latency(model.to("cpu"), torch.device("cpu"))
    lat_gpu = measure_latency(model.to(device), device) if device.type == "cuda" else None
    flops_g = maybe_flops(model.to(device), device)

    summary = {
        "model": args.model, "dataset": "cifar10", "epochs": args.epochs, "batch": args.batch,
        "opt": args.opt, "lr0": args.lr, "wd": args.wd, "sched": args.sched,
        "aug": int(not args.no_aug), "seed": args.seed,
        "val_best_acc": round(best_acc, 4), "test_acc": round(test_acc, 4), "test_f1": round(test_f1, 4),
        "params_m": hparams["params_m"],
        "flops_g": round(flops_g, 4) if flops_g is not None else None,
        "latency_cpu_ms": round(lat_cpu, 4) if lat_cpu is not None else None,
        "latency_gpu_ms": round(lat_gpu, 4) if lat_gpu is not None else None,
        "outdir": args.outdir,
    }
    print("=== RUN SUMMARY ===")
    for k, v in summary.items():
        print(f"{k}: {v}")

    res_csv = os.path.join(args.outdir, "result_summary.csv")
    write_header = not os.path.exists(res_csv)
    with open(res_csv, "a", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        if write_header:
            writer.writerow(SUMMARY_FIELDS)
        writer.writerow([summary[k] for k in SUMMARY_FIELDS])


if __name__ == "__main__":
    main()
''')
# Generate models_simple.py inside `root`; both generated scripts import SimpleCNN from it.
with open(os.path.join(root, "models_simple.py"), "w", encoding="utf-8") as f:
    f.write(r'''# models_simple.py
"""Small baseline CNN for CIFAR-10."""
import torch, torch.nn as nn


class SimpleCNN(nn.Module):
    """Five 3x3 conv + BatchNorm + ReLU layers with three 2x2 max-pools, then a 2-layer MLP head.

    Expects 3x32x32 inputs: the three pools reduce 32x32 to 4x4, matching the head's
    128*4*4 input features.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.3), nn.Linear(128 * 4 * 4, 256), nn.ReLU(inplace=True),
            nn.Dropout(0.3), nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 128, 4, 4) -> (N, 2048)
        return self.classifier(x)
''')
# Generate eval_cifar10.py inside `root`, next to train_cifar10.py and models_simple.py.
with open(os.path.join(root, "eval_cifar10.py"), "w", encoding="utf-8") as f:
    f.write(r'''# eval_cifar10.py
"""Evaluate a train_cifar10.py checkpoint on the CIFAR-10 test set.

Prints accuracy, macro-F1 and a per-class report, and writes confusion_matrix.png and
eval_summary.json into --outdir.
"""
import os, argparse, json, time
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: this runs as a subprocess with no display
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score
from models_simple import SimpleCNN

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
MODEL_CHOICES = ["simplecnn", "resnet18", "mobilenetv2", "efficientnet_b0", "vgg11_bn"]


def build_model(name: str, num_classes=10):
    """Create an untrained model whose classifier head outputs `num_classes` logits."""
    name = name.lower()
    if name == "simplecnn":
        return SimpleCNN(num_classes=num_classes)
    if name == "resnet18":
        m = models.resnet18(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    if name == "mobilenetv2":
        m = models.mobilenet_v2(weights=None)
        m.classifier[1] = nn.Linear(m.last_channel, num_classes)
        return m
    if name == "efficientnet_b0":
        m = models.efficientnet_b0(weights=None)
        m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
        return m
    if name == "vgg11_bn":
        m = models.vgg11_bn(weights=None)
        m.classifier[6] = nn.Linear(m.classifier[6].in_features, num_classes)
        return m
    raise ValueError(f"Unknown model: {name}")


@torch.no_grad()
def measure_latency(model, device, input_size=(1, 3, 32, 32), warmup=30, iters=100):
    """Mean inference latency in milliseconds for one random input (CUDA synchronized)."""
    x = torch.randn(*input_size, device=device)
    for _ in range(warmup):
        model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0 / iters


def maybe_flops(model, device):
    """Best-effort thop estimate (GMACs) for one 3x32x32 input, or None if unavailable."""
    try:
        from thop import profile
        x = torch.randn(1, 3, 32, 32, device=device)
        flops, _ = profile(model, inputs=(x,), verbose=False)
        return flops / 1e9
    except Exception:  # deliberate: thop is optional
        return None


def save_confusion_matrix(cm, class_names, path):
    """Render `cm` (rows = true class, columns = predicted) with per-cell counts to `path`."""
    fig, ax = plt.subplots(figsize=(8, 7))
    im = ax.imshow(cm, cmap="Blues")
    ax.set_title("Confusion Matrix (CIFAR-10)")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    fig.colorbar(im, ax=ax)
    for (i, j), z in np.ndenumerate(cm):
        ax.text(j, i, str(z), ha="center", va="center", fontsize=7)
    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def main():
    ap = argparse.ArgumentParser(description="Evaluate a CIFAR-10 checkpoint.")
    ap.add_argument("--data_root", type=str, default="./data")
    ap.add_argument("--checkpoint", type=str, required=True)
    ap.add_argument("--model", type=str, required=True, choices=MODEL_CHOICES)
    ap.add_argument("--outdir", type=str, default="runs/eval")
    ap.add_argument("--batch", type=int, default=256)
    args = ap.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tfm = transforms.Compose([transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)])
    test_set = datasets.CIFAR10(root=args.data_root, train=False, download=True, transform=tfm)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch, shuffle=False,
                                              num_workers=2, pin_memory=torch.cuda.is_available())

    # Load weights on CPU so the same state dict can build both the CPU and the device copy.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    model = build_model(args.model, num_classes=10)
    model.load_state_dict(ckpt["model"])
    model = model.to(device).eval()

    ys, ps = [], []
    with torch.no_grad():
        for x, y in test_loader:
            pred = model(x.to(device, non_blocking=True)).argmax(1)
            ys.append(y.numpy())
            ps.append(pred.cpu().numpy())
    y = np.concatenate(ys)
    p = np.concatenate(ps)
    labels = np.arange(len(test_set.classes))
    acc = accuracy_score(y, p)
    f1m = f1_score(y, p, average="macro")
    print(f"Test Accuracy: {acc:.4f} | F1-macro: {f1m:.4f}")
    print(classification_report(y, p, labels=labels, target_names=test_set.classes, digits=4))

    # Fixed `labels` keeps the matrix num_classes x num_classes even if a class is never predicted.
    cm = confusion_matrix(y, p, labels=labels)
    cm_path = os.path.join(args.outdir, "confusion_matrix.png")
    save_confusion_matrix(cm, test_set.classes, cm_path)
    print("Saved:", cm_path)

    params_m = sum(q.numel() for q in model.parameters() if q.requires_grad) / 1e6
    flops_g = maybe_flops(model, device)

    # Separate CPU copy in eval mode: in train mode BatchNorm cannot run on a 1x1 feature map
    # with batch size 1 (ResNet-18 at 32x32 input) and would also update running stats.
    m_cpu = build_model(args.model, num_classes=10)
    m_cpu.load_state_dict(ckpt["model"])
    m_cpu.eval()
    lat_cpu = measure_latency(m_cpu, torch.device("cpu"))
    lat_gpu = measure_latency(model, device) if device.type == "cuda" else None

    summary = {
        "model": args.model, "dataset": "cifar10", "test_acc": round(acc, 4), "test_f1": round(f1m, 4),
        "params_m": round(params_m, 4),
        "flops_g": round(flops_g, 4) if flops_g is not None else None,
        "latency_cpu_ms": round(lat_cpu, 4),
        "latency_gpu_ms": round(lat_gpu, 4) if lat_gpu is not None else None,
        "checkpoint": args.checkpoint,
    }
    summary_path = os.path.join(args.outdir, "eval_summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    print("Saved eval summary to", summary_path)


if __name__ == "__main__":
    main()
''')
# Show which files now exist in the project directory.
written = os.listdir(root)
print("Wrote:", written)


In [None]:
# Train ResNet-18 (20 epochs demo)
# Run from the generated project directory so train_cifar10.py can import models_simple.
%cd /content/cifar10_pack
# --amp enables mixed precision (it only takes effect on a CUDA device). hparams.json,
# train_log.csv, best.pt and result_summary.csv are written to --outdir.
!python train_cifar10.py --model resnet18 --epochs 20 --batch 128 --opt sgd --lr 0.1 --wd 5e-4 --sched cosine --amp --outdir /content/runs/resnet18_c10


In [None]:
# Evaluate + confusion matrix
# Loads best.pt from the training run, prints test accuracy / macro-F1 and a per-class
# report, and writes confusion_matrix.png and eval_summary.json into --outdir.
# Uses the working directory set by the training cell (%cd /content/cifar10_pack).
!python eval_cifar10.py --checkpoint /content/runs/resnet18_c10/best.pt --model resnet18 --outdir /content/runs/resnet18_c10


In [None]:
# Plot learning curves (accuracy and loss per epoch) from the training log.
import pandas as pd, matplotlib.pyplot as plt
import os

run_dir = "/content/runs/resnet18_c10"
df = pd.read_csv(os.path.join(run_dir, "train_log.csv"))

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(11, 4))
ax_acc.plot(df.epoch, df.train_acc, label="train")
ax_acc.plot(df.epoch, df.val_acc, label="val")
ax_acc.set(xlabel="epoch", ylabel="accuracy", title="ResNet-18 CIFAR-10")
ax_acc.legend()

# The training loss uses label smoothing and the validation loss does not,
# so compare the shapes of the curves, not their absolute values.
ax_loss.plot(df.epoch, df.train_loss, label="train")
ax_loss.plot(df.epoch, df.val_loss, label="val")
ax_loss.set(xlabel="epoch", ylabel="loss", title="ResNet-18 CIFAR-10 (loss)")
ax_loss.legend()

fig.tight_layout()
curve_path = os.path.join(run_dir, "acc_curve.png")
fig.savefig(curve_path, dpi=200)
plt.show()
print(f"Saved curve -> {curve_path}")
