In [1]:
# basic_pixel_diffusion_kaggle.py
# Hard-coded dataset + captions for your Kaggle paths.
import os, math, random, json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

# --------------------------------------
# HARD-CODED CONFIG (uses your Kaggle paths)
# --------------------------------------
CONFIG = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "img_size": 256,
    "batch_size": 8,
    "epochs": 20,
    "lr": 2e-4,
    "timesteps": 1000,
    "sample_steps": 50,
    "guidance_scale": 5.0,
    "cf_prob": 0.1,
    "save_dir": "./basic_diff_ckpts",
    "val_fraction": 0.1,
    # --- YOUR PATHS (hard-coded) ---
    "data_root": "/kaggle/input/combineed/kaggle/working/dataset_combined",
    "captions_json_folder": "/kaggle/input/jssonfile/kaggle/working/caption_jsons_multiGPU",
}
os.makedirs(CONFIG["save_dir"], exist_ok=True)
os.makedirs(f"{CONFIG['save_dir']}/samples", exist_ok=True)
torch.manual_seed(42); random.seed(42); np.random.seed(42)

# --------------------------------------
# Utilities
# --------------------------------------
def save_image_grid(tensor, path, nrow=4):
    grid = utils.make_grid((tensor.clamp(-1,1)+1)/2.0, nrow=nrow)
    arr = grid.mul(255).permute(1,2,0).cpu().numpy().astype("uint8")
    Image.fromarray(arr).save(path)

def plot_and_save(history, save_dir):
    epochs = range(1, len(history['train_loss'])+1)
    plt.figure(figsize=(8,4))
    plt.plot(epochs, history['train_loss'], label='train_loss')
    if history.get('val_loss') is not None:
        plt.plot(epochs, history['val_loss'], label='val_loss')
    plt.xlabel('epoch'); plt.ylabel('MSE loss'); plt.legend(); plt.title('Loss')
    plt.tight_layout(); plt.savefig(os.path.join(save_dir, 'loss_curve.png')); plt.close()

    plt.figure(figsize=(8,4))
    plt.plot(epochs, history['train_acc'], label='train_acc')
    if history.get('val_acc') is not None:
        plt.plot(epochs, history['val_acc'], label='val_acc')
    plt.xlabel('epoch'); plt.ylabel('accuracy'); plt.legend(); plt.title('Accuracy')
    plt.tight_layout(); plt.savefig(os.path.join(save_dir, 'acc_curve.png')); plt.close()
    print("Saved plots to", save_dir)

# --------------------------------------
# Caption JSON loader (robust minimal)
# --------------------------------------
def load_caption_map(json_folder: str):
    caption_map = {}
    if json_folder is None:
        return caption_map
    jfolder = Path(json_folder)
    if not jfolder.exists():
        print("Caption folder not found:", json_folder)
        return caption_map
    files = []
    if jfolder.is_file():
        files = [jfolder]
    else:
        files = [p for p in jfolder.rglob("*") if p.suffix.lower() in ('.json','.ndjson','.jsonl')]
    for jf in files:
        try:
            data = json.load(open(jf, 'r'))
            if isinstance(data, list):
                for obj in data:
                    fn = obj.get('image_filename') or (obj.get('image_path') and Path(obj.get('image_path')).name)
                    cap = obj.get('caption') or obj.get('text') or " "
                    if fn:
                        caption_map[Path(fn).name] = cap
            elif isinstance(data, dict):
                # either mapping filename->caption or single object with keys
                if 'image_filename' in data and 'caption' in data:
                    caption_map[Path(data['image_filename']).name] = data.get('caption',' ')
                else:
                    for k, v in data.items():
                        if isinstance(v, str):
                            caption_map[Path(k).name] = v
                        elif isinstance(v, dict) and 'caption' in v:
                            caption_map[Path(k).name] = v['caption']
        except Exception:
            # try line-delimited json
            try:
                with open(jf, 'r') as fh:
                    for line in fh:
                        line = line.strip()
                        if not line: continue
                        obj = json.loads(line)
                        fn = obj.get('image_filename') or (obj.get('image_path') and Path(obj.get('image_path')).name)
                        cap = obj.get('caption') or obj.get('text') or " "
                        if fn: caption_map[Path(fn).name] = cap
            except Exception:
                print("Warning: could not parse", jf)
                continue
    print(f"Loaded {len(caption_map)} captions from {json_folder}")
    return caption_map

# --------------------------------------
# Dataset using captions
# --------------------------------------
class SimpleImageDatasetWithCaptions(Dataset):
    def __init__(self, root, captions_map=None, img_size=256):
        self.root = Path(root)
        self.files = sorted([p for p in self.root.rglob("*") if p.suffix.lower() in (".jpg",".jpeg",".png")])
        self.caption_map = captions_map or {}
        # ensure every file has an entry
        for p in self.files:
            if p.name not in self.caption_map:
                self.caption_map[p.name] = " "
        self.tr = transforms.Compose([
            transforms.Resize((img_size,img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])
    def __len__(self): return len(self.files)
    def __getitem__(self, idx):
        p = self.files[idx]
        img = self.tr(Image.open(p).convert("RGB"))
        cap = self.caption_map.get(p.name, " ")
        return img, cap, str(p)

# --------------------------------------
# Simple Tokenizer + Text Encoder
# --------------------------------------
class SimpleTokenizer:
    def __init__(self, max_len=40):
        self.max_len = max_len
        self.word2idx = {'<pad>':0,'<unk>':1,'<bos>':2,'<eos>':3}
        self.idx2word = {v:k for k,v in self.word2idx.items()}
        self.vocab_size = len(self.word2idx)
    def build_vocab(self, captions_list, min_freq=1, max_words=20000):
        freq = {}
        for c in captions_list:
            text = c.lower().strip()
            for ch in ['.',',',';','!','?','"',"'","(",")",":","/","\\"]:
                text = text.replace(ch,' ')
            toks = [t for t in text.split() if t!='']
            for t in toks:
                freq[t] = freq.get(t,0) + 1
        items = [w for w,f in sorted(freq.items(), key=lambda x:-x[1]) if f>=min_freq][:max_words]
        for w in items:
            if w not in self.word2idx:
                idx = len(self.word2idx); self.word2idx[w]=idx; self.idx2word[idx]=w
        self.vocab_size = len(self.word2idx)
    def encode(self, text):
        text = text.lower().strip()
        for ch in ['.',',',';','!','?','"',"'","(",")",":","/","\\"]:
            text = text.replace(ch,' ')
        toks = [t for t in text.split() if t!=''][:self.max_len]
        ids = [self.word2idx.get('<bos>')]
        for t in toks:
            ids.append(self.word2idx.get(t, self.word2idx['<unk>']))
        ids.append(self.word2idx.get('<eos>'))
        if len(ids) < self.max_len:
            ids += [self.word2idx['<pad>']] * (self.max_len - len(ids))
        else:
            ids = ids[:self.max_len]
        return torch.tensor(ids, dtype=torch.long)

class SimpleTextEncoderAvg(nn.Module):
    def __init__(self, vocab_size, emb_dim=256):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
    def forward(self, token_ids):
        e = self.emb(token_ids)
        mask = (token_ids != 0).unsqueeze(-1).float()
        summed = (e * mask).sum(1)
        lens = mask.sum(1).clamp(min=1.0)
        return summed / lens

# --------------------------------------
# Model / Conditioner / Scheduler
# --------------------------------------
def conv_block(i,o):
    return nn.Sequential(nn.Conv2d(i,o,3,1,1), nn.GroupNorm(8,o), nn.SiLU())

class SimpleUNet(nn.Module):
    def __init__(self, in_ch=3, base=64, text_dim=256, cond_dim=128):
        super().__init__()
        self.enc1 = conv_block(in_ch, base)
        self.enc2 = conv_block(base, base*2)
        self.enc3 = conv_block(base*2, base*4)
        self.mid = conv_block(base*4, base*4)
        self.dec3 = conv_block(base*8, base*4)
        self.dec2 = conv_block(base*6, base*2)
        self.dec1 = conv_block(base*3, base)
        self.out = nn.Conv2d(base, in_ch, 1)
        self.text_proj = nn.Linear(text_dim, base*4)
        self.cond_proj = nn.Conv2d(cond_dim, base*4, 1)
    def forward(self, x, cond_spatial=None, text_emb=None):
        e1 = self.enc1(x)
        e2 = self.enc2(F.avg_pool2d(e1,2))
        e3 = self.enc3(F.avg_pool2d(e2,2))
        mid = self.mid(F.avg_pool2d(e3,2))
        if cond_spatial is not None:
            cs = self.cond_proj(cond_spatial)
            cs = F.interpolate(cs, size=mid.shape[2:], mode="nearest")
            mid = mid + cs
        if text_emb is not None:
            tproj = self.text_proj(text_emb).view(text_emb.shape[0], -1, 1, 1)
            mid = mid + tproj
        d3 = F.interpolate(mid, scale_factor=2)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, scale_factor=2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, scale_factor=2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))
        return self.out(d1)

class ConvConditioner(nn.Module):
    def __init__(self, in_c=3, cond_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, cond_dim, 3,1,1),
            nn.SiLU(),
            nn.Conv2d(cond_dim, cond_dim, 3,1,1),
            nn.SiLU()
        )
    def forward(self, img):
        return self.net(img)

class SimpleScheduler:
    def __init__(self, timesteps=1000, device="cpu"):
        betas = torch.linspace(1e-4, 0.02, timesteps, device=device)
        alphas = 1 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.timesteps = timesteps
    def q_sample(self, x0, t, noise):
        a = self.alphas_cumprod[t].view(-1,1,1,1)
        return a.sqrt()*x0 + (1-a).sqrt()*noise

@torch.no_grad()
def ddim_sample(model, sched, shape, cond_spatial, text_emb, steps, device, scale):
    x = torch.randn(shape, device=device)
    idxs = torch.linspace(sched.timesteps-1, 0, steps).long().to(device)
    for i, t in enumerate(idxs):
        eps_c = model(x, cond_spatial, text_emb)
        eps_uc = model(x, cond_spatial, None)
        eps = eps_uc + scale*(eps_c-eps_uc)
        a = sched.alphas_cumprod[t].sqrt().view(1,1,1,1)
        x = (x - (1-a)*eps) / a
    return x

# --------------------------------------
# Metric utilities
# --------------------------------------
def noise_prediction_accuracy(pred, true, tol=0.05):
    diff = (pred - true).abs()
    correct = (diff <= tol).float()
    return correct.mean().item()

# --------------------------------------
# TRAINING LOOP with metrics & val
# --------------------------------------
def train():
    device = CONFIG["device"]
    captions_map = load_caption_map(CONFIG["captions_json_folder"])
    ds = SimpleImageDatasetWithCaptions(CONFIG["data_root"], captions_map, img_size=CONFIG["img_size"])
    N = len(ds)
    if N == 0:
        raise RuntimeError(f"No images found in {CONFIG['data_root']}")
    val_n = max(int(CONFIG['val_fraction'] * N), 1)
    train_n = N - val_n
    train_ds, val_ds = random_split(ds, [train_n, val_n])
    train_dl = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=2, drop_last=True)
    val_dl = DataLoader(val_ds, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=2)
    # build tokenizer from dataset captions
    all_captions = [ds.caption_map.get(Path(p).name, " ") for p in ds.files]
    tok = SimpleTokenizer()
    tok.build_vocab(all_captions)
    text_enc = SimpleTextEncoderAvg(tok.vocab_size, emb_dim=256).to(device)
    cond = ConvConditioner(in_c=3, cond_dim=128).to(device)
    den = SimpleUNet(in_ch=3, base=64, text_dim=256, cond_dim=128).to(device)
    opt = torch.optim.Adam(list(den.parameters()) + list(cond.parameters()) + list(text_enc.parameters()), lr=CONFIG["lr"])
    sched = SimpleScheduler(CONFIG["timesteps"], device=device)
    mse = nn.MSELoss()
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    print(f"Training on {train_n} images, validating on {val_n} images. Device: {device}")
    for epoch in range(1, CONFIG["epochs"]+1):
        den.train(); cond.train(); text_enc.train()
        running_loss = 0.0; running_acc = 0.0; seen = 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{CONFIG['epochs']}")
        for imgs, caps, _ in pbar:
            imgs = imgs.to(device)
            B = imgs.shape[0]
            tokens = torch.stack([tok.encode(c) for c in caps], dim=0).to(device)
            text_emb = text_enc(tokens)
            drop = (torch.rand(B, device=device) < CONFIG["cf_prob"]).float().unsqueeze(1)
            text_emb_cf = text_emb * (1 - drop)
            cond_in = F.interpolate(imgs, size=(imgs.shape[2]//4, imgs.shape[3]//4), mode='bilinear', align_corners=False)
            cond_feat = cond(cond_in)
            t = torch.randint(0, CONFIG["timesteps"], (B,), device=device)
            noise = torch.randn_like(imgs)
            zt = sched.q_sample(imgs, t, noise)
            pred = den(zt, cond_feat, text_emb_cf)
            loss = mse(pred, noise)
            opt.zero_grad(); loss.backward(); opt.step()
            acc = noise_prediction_accuracy(pred.detach(), noise.detach(), tol=0.05)
            running_loss += float(loss.item()) * B
            running_acc += acc * B
            seen += B
            if seen % (CONFIG["batch_size"] * 2) == 0:
                pbar.set_postfix(train_loss=(running_loss/seen), train_acc=(running_acc/seen))
        epoch_train_loss = running_loss / max(1, seen)
        epoch_train_acc = running_acc / max(1, seen)
        history['train_loss'].append(epoch_train_loss); history['train_acc'].append(epoch_train_acc)
        # validation
        den.eval(); cond.eval(); text_enc.eval()
        val_loss = 0.0; val_acc = 0.0; vseen = 0
        with torch.no_grad():
            for imgs, caps, _ in val_dl:
                imgs = imgs.to(device)
                B = imgs.shape[0]
                tokens = torch.stack([tok.encode(c) for c in caps], dim=0).to(device)
                text_emb = text_enc(tokens)
                cond_in = F.interpolate(imgs, size=(imgs.shape[2]//4, imgs.shape[3]//4), mode='bilinear', align_corners=False)
                cond_feat = cond(cond_in)
                t = torch.randint(0, CONFIG["timesteps"], (B,), device=device)
                noise = torch.randn_like(imgs)
                zt = sched.q_sample(imgs, t, noise)
                pred = den(zt, cond_feat, text_emb)
                loss = mse(pred, noise)
                acc = noise_prediction_accuracy(pred, noise, tol=0.05)
                val_loss += float(loss.item()) * B
                val_acc += acc * B
                vseen += B
        epoch_val_loss = val_loss / max(1, vseen)
        epoch_val_acc = val_acc / max(1, vseen)
        history['val_loss'].append(epoch_val_loss); history['val_acc'].append(epoch_val_acc)
        print(f"Epoch {epoch} summary: train_loss={epoch_train_loss:.6f} train_acc={epoch_train_acc:.4f}  val_loss={epoch_val_loss:.6f} val_acc={epoch_val_acc:.4f}")
        ck = {"den": den.state_dict(), "cond": cond.state_dict(), "text_enc": text_enc.state_dict(), "opt": opt.state_dict(), "epoch": epoch}
        torch.save(ck, os.path.join(CONFIG["save_dir"], f"ckpt_epoch_{epoch}.pth"))
        plot_and_save(history, CONFIG["save_dir"])
        # save sample from validation
        try:
            sample_batch = next(iter(val_dl))
            imgs0, caps0, _ = sample_batch
            imgs0 = imgs0.to(device)[:4]; caps0 = caps0[:4]
            tokens0 = torch.stack([tok.encode(c) for c in caps0], dim=0).to(device)
            t_emb = text_enc(tokens0)
            cond_in = F.interpolate(imgs0, size=(imgs0.shape[2]//4, imgs0.shape[3]//4), mode='bilinear', align_corners=False)
            cond_spatial_sample = cond(cond_in)
            x = ddim_sample(den, sched, imgs0.shape, cond_spatial_sample, t_emb, steps=CONFIG['sample_steps'], device=device, scale=CONFIG['guidance_scale'])
            save_image_grid(x.cpu(), os.path.join(CONFIG['save_dir'], "samples", f"epoch_{epoch}_val_sample.png"))
        except Exception as e:
            print("Sample save failed:", e)
    plot_and_save(history, CONFIG["save_dir"])
    print("Training complete. History saved in", CONFIG["save_dir"])
    return history

# --------------------------------------
# Sampling helpers (same as before)
# --------------------------------------
@torch.no_grad()
def sample_text2img(prompt="a painting of a cat", ckpt=None, out=None):
    device = CONFIG["device"]
    if ckpt is None:
        ckpt = sorted(Path(CONFIG["save_dir"]).glob("ckpt_epoch_*.pth"))[-1]
    state = torch.load(str(ckpt), map_location=device)
    den = SimpleUNet().to(device); den.load_state_dict(state["den"])
    cond = ConvConditioner().to(device); cond.load_state_dict(state["cond"])
    text_enc = SimpleTextEncoderAvg(state["text_enc"]['emb.weight'].shape[0], emb_dim=256).to(device)
    text_enc.load_state_dict(state["text_enc"])
    tokens = torch.zeros((4,40), dtype=torch.long, device=device)
    text_emb = text_enc(tokens)
    cond_spatial = torch.zeros((4,128,CONFIG["img_size"]//4, CONFIG["img_size"]//4), device=device)
    sched = SimpleScheduler(CONFIG["timesteps"], device=device)
    x = ddim_sample(den, sched, (4,3,CONFIG["img_size"],CONFIG["img_size"]), cond_spatial, text_emb, steps=CONFIG["sample_steps"], device=device, scale=CONFIG["guidance_scale"])
    out = out or os.path.join(CONFIG["save_dir"], "samples", "text2img.png")
    save_image_grid(x.cpu(), out)
    print("Saved text2img sample to", out)

@torch.no_grad()
def sample_img2img(source_img_path, ckpt=None, out=None, strength=0.6):
    device = CONFIG["device"]
    if ckpt is None:
        ckpt = sorted(Path(CONFIG["save_dir"]).glob("ckpt_epoch_*.pth"))[-1]
    state = torch.load(str(ckpt), map_location=device)
    den = SimpleUNet().to(device); den.load_state_dict(state["den"])
    cond = ConvConditioner().to(device); cond.load_state_dict(state["cond"])
    text_enc = SimpleTextEncoderAvg(state["text_enc"]['emb.weight'].shape[0], emb_dim=256).to(device)
    text_enc.load_state_dict(state["text_enc"])
    tr = transforms.Compose([transforms.Resize((CONFIG["img_size"],CONFIG["img_size"])), transforms.ToTensor(), transforms.Normalize([0.5]*3,[0.5]*3)])
    img = tr(Image.open(source_img_path).convert("RGB")).unsqueeze(0).to(device)
    cond_in = F.interpolate(img, size=(img.shape[2]//4, img.shape[3]//4), mode='bilinear', align_corners=False)
    cond_feat = cond(cond_in)
    sched = SimpleScheduler(CONFIG["timesteps"], device=device)
    t_int = int(strength * (sched.timesteps - 1))
    noise = torch.randn_like(img)
    zt = sched.q_sample(img, torch.tensor([t_int], device=device), noise)
    out_img = ddim_sample(den, sched, img.shape, cond_feat, None, steps=CONFIG['sample_steps'], device=device, scale=1.0)
    out = out or os.path.join(CONFIG["save_dir"], "samples", "img2img.png")
    save_image_grid(out_img.cpu(), out)
    print("Saved img2img to", out)


print("Starting training on Kaggle dataset paths...")
history = train()
print("Sampling examples...")
try:
    sample_text2img()
except Exception as e:
    print("Text2Img sampling failed:", e)
try:
    # try to sample using a real example if present
    example_path = next(Path(CONFIG['data_root']).rglob("*.jpg"), None)
    if example_path is not None:
        sample_img2img(str(example_path))
    else:
        print("No example image found for img2img sampling.")
except Exception as e:
    print("Img2Img sampling failed:", e)



Starting training on Kaggle dataset paths...
Loaded 2109 captions from /kaggle/input/jssonfile/kaggle/working/caption_jsons_multiGPU
Training on 6336 images, validating on 704 images. Device: cuda


Epoch 1/20: 100%|██████████| 792/792 [03:15<00:00,  4.04it/s, train_acc=0.24, train_loss=0.0889] 


Epoch 1 summary: train_loss=0.088865 train_acc=0.2396  val_loss=0.053972 val_acc=0.3090
Saved plots to ./basic_diff_ckpts


Epoch 2/20: 100%|██████████| 792/792 [03:21<00:00,  3.93it/s, train_acc=0.346, train_loss=0.0469]


Epoch 2 summary: train_loss=0.046893 train_acc=0.3459  val_loss=0.042564 val_acc=0.3966
Saved plots to ./basic_diff_ckpts


Epoch 3/20: 100%|██████████| 792/792 [03:21<00:00,  3.92it/s, train_acc=0.393, train_loss=0.0435]


Epoch 3 summary: train_loss=0.043489 train_acc=0.3926  val_loss=0.041409 val_acc=0.4082
Saved plots to ./basic_diff_ckpts


Epoch 4/20: 100%|██████████| 792/792 [03:22<00:00,  3.92it/s, train_acc=0.419, train_loss=0.0421]


Epoch 4 summary: train_loss=0.042135 train_acc=0.4186  val_loss=0.047362 val_acc=0.4319
Saved plots to ./basic_diff_ckpts


Epoch 5/20: 100%|██████████| 792/792 [03:21<00:00,  3.92it/s, train_acc=0.441, train_loss=0.0397]


Epoch 5 summary: train_loss=0.039732 train_acc=0.4412  val_loss=0.033756 val_acc=0.4300
Saved plots to ./basic_diff_ckpts


Epoch 6/20: 100%|██████████| 792/792 [03:21<00:00,  3.93it/s, train_acc=0.468, train_loss=0.0375]


Epoch 6 summary: train_loss=0.037522 train_acc=0.4679  val_loss=0.039232 val_acc=0.4714
Saved plots to ./basic_diff_ckpts


Epoch 7/20: 100%|██████████| 792/792 [03:22<00:00,  3.92it/s, train_acc=0.475, train_loss=0.0377]


Epoch 7 summary: train_loss=0.037741 train_acc=0.4752  val_loss=0.031833 val_acc=0.4817
Saved plots to ./basic_diff_ckpts


Epoch 8/20: 100%|██████████| 792/792 [03:21<00:00,  3.93it/s, train_acc=0.482, train_loss=0.0381]


Epoch 8 summary: train_loss=0.038060 train_acc=0.4818  val_loss=0.031130 val_acc=0.4960
Saved plots to ./basic_diff_ckpts


Epoch 9/20: 100%|██████████| 792/792 [03:22<00:00,  3.92it/s, train_acc=0.507, train_loss=0.0353]


Epoch 9 summary: train_loss=0.035314 train_acc=0.5067  val_loss=0.035065 val_acc=0.5258
Saved plots to ./basic_diff_ckpts


Epoch 10/20: 100%|██████████| 792/792 [03:21<00:00,  3.92it/s, train_acc=0.519, train_loss=0.0332]


Epoch 10 summary: train_loss=0.033245 train_acc=0.5185  val_loss=0.033424 val_acc=0.5028
Saved plots to ./basic_diff_ckpts


Epoch 11/20: 100%|██████████| 792/792 [03:22<00:00,  3.92it/s, train_acc=0.542, train_loss=0.032] 


Epoch 11 summary: train_loss=0.032002 train_acc=0.5422  val_loss=0.030329 val_acc=0.5526
Saved plots to ./basic_diff_ckpts


Epoch 12/20: 100%|██████████| 792/792 [03:21<00:00,  3.92it/s, train_acc=0.538, train_loss=0.0327]


Epoch 12 summary: train_loss=0.032714 train_acc=0.5383  val_loss=0.032893 val_acc=0.5030
Saved plots to ./basic_diff_ckpts


Epoch 13/20: 100%|██████████| 792/792 [03:22<00:00,  3.92it/s, train_acc=0.538, train_loss=0.034] 


Epoch 13 summary: train_loss=0.033979 train_acc=0.5375  val_loss=0.029260 val_acc=0.5750
Saved plots to ./basic_diff_ckpts


Epoch 14/20: 100%|██████████| 792/792 [03:22<00:00,  3.92it/s, train_acc=0.559, train_loss=0.0324]


Epoch 14 summary: train_loss=0.032378 train_acc=0.5593  val_loss=0.023800 val_acc=0.5855
Saved plots to ./basic_diff_ckpts


Epoch 15/20: 100%|██████████| 792/792 [03:21<00:00,  3.94it/s, train_acc=0.566, train_loss=0.0305]


Epoch 15 summary: train_loss=0.030522 train_acc=0.5657  val_loss=0.028301 val_acc=0.5750
Saved plots to ./basic_diff_ckpts


Epoch 16/20: 100%|██████████| 792/792 [03:21<00:00,  3.93it/s, train_acc=0.568, train_loss=0.0305]


Epoch 16 summary: train_loss=0.030539 train_acc=0.5683  val_loss=0.025819 val_acc=0.5989
Saved plots to ./basic_diff_ckpts


Epoch 17/20: 100%|██████████| 792/792 [03:21<00:00,  3.93it/s, train_acc=0.572, train_loss=0.0313]


Epoch 17 summary: train_loss=0.031271 train_acc=0.5723  val_loss=0.030193 val_acc=0.5778
Saved plots to ./basic_diff_ckpts


Epoch 18/20: 100%|██████████| 792/792 [03:21<00:00,  3.93it/s, train_acc=0.578, train_loss=0.03]  


Epoch 18 summary: train_loss=0.030047 train_acc=0.5783  val_loss=0.028689 val_acc=0.5853
Saved plots to ./basic_diff_ckpts


Epoch 19/20: 100%|██████████| 792/792 [03:21<00:00,  3.93it/s, train_acc=0.59, train_loss=0.0287] 


Epoch 19 summary: train_loss=0.028667 train_acc=0.5896  val_loss=0.027965 val_acc=0.5857
Saved plots to ./basic_diff_ckpts


Epoch 20/20: 100%|██████████| 792/792 [03:21<00:00,  3.93it/s, train_acc=0.591, train_loss=0.0296]


Epoch 20 summary: train_loss=0.029589 train_acc=0.5905  val_loss=0.029487 val_acc=0.5992
Saved plots to ./basic_diff_ckpts
Saved plots to ./basic_diff_ckpts
Training complete. History saved in ./basic_diff_ckpts
Sampling examples...
Saved text2img sample to ./basic_diff_ckpts/samples/text2img.png
Saved img2img to ./basic_diff_ckpts/samples/img2img.png
