# Kaggle → Colab → Anywhere: DCGAN Starter
Szybki przepływ pracy z checkpointami i synchronizacją (Kaggle Datasets / HF Hub / W&B).

In [None]:
import os, sys, subprocess
def pip_install(pkgs):
    """Install packages into the running interpreter's environment with pip.

    Args:
        pkgs: Iterable of pip requirement strings (for example
            ``['tqdm', 'wandb']``) or a single requirement string.
            An empty iterable is a no-op.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # A bare string would otherwise be split into single characters by list().
    if isinstance(pkgs, str):
        pkgs = [pkgs]
    pkgs = list(pkgs)
    if not pkgs:
        # pip fails when invoked without any requirement; nothing to install.
        return
    # Use sys.executable so packages land in the kernel's own environment.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--quiet'] + pkgs)
# Install PyTorch only when it cannot be imported; hosted notebook images
# usually ship it already, and reinstalling is slow.
try:
    import torch  # noqa
    import torchvision  # noqa
except Exception:
    pip_install(['torch', 'torchvision'])

# The remaining helper packages are always passed to pip.
_extra_packages = ['tqdm', 'omegaconf', 'wandb', 'kaggle', 'huggingface_hub']
pip_install(_extra_packages)
print('Setup OK')


In [None]:
import pathlib, platform, torch
# Working directories, all relative to the current working directory.
BASE = pathlib.Path('.')
CKPT = BASE / 'checkpoints'
LOGS = BASE / 'logs'
DATA = BASE / 'data'

# Create every directory up front; exist_ok makes re-running the cell safe.
for d in (CKPT, LOGS, DATA):
    d.mkdir(parents=True, exist_ok=True)

# Report the runtime so logs show which environment produced them.
print(
    'Python', platform.python_version(),
    'Torch', torch.__version__,
    'CUDA', torch.cuda.is_available(),
)


In [None]:
import sys

# Make the starter repository's `utils` package importable. Two candidate
# locations are appended: one two levels above BASE and one relative to the
# current working directory.
_starter_dir = (BASE / '..' / '..' / 'gan-cross-platform-starter').resolve()
sys.path.append(str(_starter_dir))
sys.path.append('gan-cross-platform-starter')

# Checkpoint and sync helpers provided by the starter repository.
from utils.checkpoint import latest_checkpoint, save_checkpoint, load_checkpoint
from utils.sync_kaggle import kaggle_dataset_push, kaggle_dataset_pull
from utils.sync_hf import hf_snapshot_upload, hf_snapshot_download

print('Helpers OK')


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Preprocessing: resize to 32 px so the image size matches the 32x32 output of
# the generator, convert to a tensor, then normalize with mean 0.5 / std 0.5
# (which lines the data range up with the generator's Tanh output).
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split; downloaded into ./data on first use.
train_set = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 128 images, loaded by two worker processes.
loader = DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Notebook cell output: (number of images, number of batches per epoch).
len(train_set), len(loader)


In [None]:
import torch, torch.nn as nn
# Size of the latent (noise) vector fed to the generator.
nz = 64


class G(nn.Module):
    """DCGAN-style generator.

    Maps a latent tensor of shape (N, nz, 1, 1) to a single-channel image of
    shape (N, 1, 32, 32). The spatial size grows 1 -> 4 -> 8 -> 16 -> 32 and
    the final Tanh bounds the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Kept as one flat Sequential named `net` so state_dict keys
        # (net.0.weight, net.1.weight, ...) stay compatible with checkpoints.
        layers = [
            # (N, nz, 1, 1) -> (N, 256, 4, 4)
            nn.ConvTranspose2d(nz, 256, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # -> (N, 128, 8, 8)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # -> (N, 64, 16, 16)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # -> (N, 1, 32, 32)
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from the latent batch `z`."""
        return self.net(z)
class D(nn.Module):
    """DCGAN-style discriminator.

    Takes single-channel images and returns raw logits (no sigmoid), flattened
    to one dimension. For 32x32 inputs the spatial size shrinks
    32 -> 16 -> 8 -> 4 -> 1, giving exactly one logit per image.
    """

    def __init__(self):
        super().__init__()
        # Kept as one flat Sequential named `net` so state_dict keys
        # (net.0.weight, net.2.weight, ...) stay compatible with checkpoints.
        layers = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # Final 4x4 convolution without padding collapses a 4x4 map to 1x1.
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0, bias=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return flattened real/fake logits for the image batch `x`."""
        logits = self.net(x)
        return logits.view(-1)
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Instantiate both networks (generator first, then discriminator) on `device`.
Gnet = G().to(device)
Dnet = D().to(device)

# One Adam optimizer per network, both with lr=2e-4 and betas=(0.5, 0.999).
optG = torch.optim.Adam(Gnet.parameters(), lr=2e-4, betas=(0.5, 0.999))
optD = torch.optim.Adam(Dnet.parameters(), lr=2e-4, betas=(0.5, 0.999))

# The discriminator returns raw logits, so the loss applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()

print('Model ready on', device)


In [None]:
# Resume from the newest checkpoint in ./checkpoints when one exists.
ckpt = latest_checkpoint('checkpoints')
global_step = 0
if not ckpt:
    print('No checkpoint – start fresh')
else:
    # Load onto the CPU first; load_state_dict copies into the live modules.
    state = load_checkpoint(ckpt, 'cpu')
    # Restore networks before optimizers, in the same order they were saved.
    for target, key in ((Gnet, 'G'), (Dnet, 'D'), (optG, 'optG'), (optD, 'optD')):
        target.load_state_dict(state[key])
    # Checkpoints without a 'step' entry resume counting from zero.
    global_step = int(state.get('step', 0))
    print('Resumed from', ckpt, 'step', global_step)


In [None]:
from tqdm import tqdm; import wandb, os
import pathlib

# Weights & Biases logging is enabled only when an API key is configured.
use_wandb = bool(os.getenv('WANDB_API_KEY'))
if use_wandb:
    wandb.init(project=os.getenv('WANDB_PROJECT', 'gan-starter'), reinit=True)

# `nz` and `device` are deliberately NOT redefined here: the latent size and
# device must match the ones Gnet/Dnet were built with in the model cell.
epochs = 1
save_every = 500


def _save_training_checkpoint(step):
    """Write model and optimizer state for `step` to checkpoints/ckpt_<step>.pt."""
    p = pathlib.Path('checkpoints') / f'ckpt_{step:07d}.pt'
    save_checkpoint(str(p), {'step': step, 'G': Gnet.state_dict(), 'D': Dnet.state_dict(),
                             'optG': optG.state_dict(), 'optD': optD.state_dict(),
                             'cfg': {'nz': nz}})


# Remember where this run started so the final save only fires after progress.
start_step = global_step
for epoch in range(epochs):
    pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{epochs}', ncols=100)
    for real, _ in pbar:
        real = real.to(device)
        bs = real.size(0)

        # --- Discriminator step: real images -> 1, generated images -> 0. ---
        z = torch.randn(bs, nz, 1, 1, device=device)
        fake = Gnet(z).detach()  # detach: no generator gradients in this step
        lossD = criterion(Dnet(real), torch.ones(bs, device=device)) + \
                criterion(Dnet(fake), torch.zeros(bs, device=device))
        optD.zero_grad()
        lossD.backward()
        optD.step()

        # --- Generator step: make the discriminator label fresh fakes as 1. ---
        z = torch.randn(bs, nz, 1, 1, device=device)
        fake = Gnet(z)
        lossG = criterion(Dnet(fake), torch.ones(bs, device=device))
        optG.zero_grad()
        lossG.backward()
        optG.step()

        global_step += 1
        if use_wandb:
            wandb.log({'lossD': lossD.item(), 'lossG': lossG.item(), 'step': global_step})
        if global_step % save_every == 0:
            _save_training_checkpoint(global_step)

# Always persist the final state. One MNIST epoch at batch size 128 is fewer
# than `save_every` steps, so without this a fresh run would write no
# checkpoint at all and the sync cells would have nothing to upload.
if global_step > start_step and global_step % save_every != 0:
    _save_training_checkpoint(global_step)
print('Done. Last step =', global_step)


### Push/Pull – Kaggle Datasets

In [None]:
import os
# Target dataset "<user>/<name>"; the default is a placeholder to override via
# the KAGGLE_DATASET_SLUG environment variable.
slug = os.getenv('KAGGLE_DATASET_SLUG', 'your_kaggle_username/gan-checkpoints')

# Upload the local checkpoints folder. Best effort: failures are reported,
# not raised, so the notebook keeps running.
try:
    kaggle_dataset_push(slug, folder='checkpoints', title='GAN Checkpoints', is_public=False)
except Exception as err:
    print('Kaggle push skipped/error:', err)
else:
    print('Pushed to Kaggle:', slug)

# Download the same dataset into ./downloaded, again on a best-effort basis.
try:
    out = kaggle_dataset_pull(slug, out_dir='downloaded')
except Exception as err:
    print('Kaggle pull skipped/error:', err)
else:
    print('Pulled to:', out)


### Push/Pull – Hugging Face Hub

In [None]:
# Hugging Face credentials and target repo come from the environment; the
# repo default is a placeholder.
HF_TOKEN = os.getenv('HF_TOKEN')
HF_REPO = os.getenv('HF_REPO', 'username/gan-checkpoints')

if not HF_TOKEN:
    # Without a token neither upload nor download is attempted.
    print('Set HF_TOKEN to use HF sync')
else:
    # Upload the local checkpoints folder; errors are reported, not raised.
    try:
        hf_snapshot_upload(HF_REPO, local_dir='checkpoints', token=HF_TOKEN, private=True)
    except Exception as err:
        print('HF upload error:', err)
    else:
        print('Pushed to HF:', HF_REPO)

    # Download the repo snapshot into ./downloaded_hf, also best effort.
    try:
        path = hf_snapshot_download(HF_REPO, local_dir='downloaded_hf', token=HF_TOKEN)
    except Exception as err:
        print('HF download error:', err)
    else:
        print('Pulled from HF to:', path)
