In [1]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/clouds-ml/src

Mounted at /content/drive
/content/drive/MyDrive/clouds-ml/src


In [None]:
import os, math, random, time
import pandas as pd
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.amp import GradScaler, autocast

from tqdm.auto import tqdm
from config import BASE_DIR, DATA_PROCESSED, MODELS_DIR

In [None]:
CSV_PATH  = os.path.join(DATA_PROCESSED, "sample_10k_ready.csv")
SAVE_PATH = os.path.join(MODELS_DIR, "resnet18_big10k.pth")
os.makedirs(MODELS_DIR, exist_ok=True)

In [None]:
df = pd.read_csv(CSV_PATH)
def to_full(p: str | None):
    if isinstance(p, str):
        return os.path.join(BASE_DIR, p)
    return None

df["full_path"] = df["local_path"].apply(to_full)
df["full_path"] = df["full_path"].str.replace(
    "/content/drive/MyDrive/clouds-ml/data/images_10k", "/content/images_10k", regex=False
)
df = df[df["full_path"].apply(lambda p: isinstance(p, str) and os.path.isfile(p))].reset_index(drop=True)


In [None]:
CLOUD_COLS = [
    "altocumulus","altostratus","cirrocumulus","cirrostratus",
    "cirrus","cumulonimbus","cumulus","nimbostratus","stratocumulus","stratus"
]
num_classes = len(CLOUD_COLS)

class CloudDataset(Dataset):
    def __init__(self, df, transform=None):
        self.paths  = df["full_path"].tolist()
        self.labels = df[CLOUD_COLS].to_numpy(dtype=np.float32)
        self.transform = transform
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        img = Image.open(self.paths[i]).convert("RGB")
        if self.transform: img = self.transform(img)
        return img, torch.from_numpy(self.labels[i])

img_size = 224
train_tfms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1, hue=0.03),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])
val_tfms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

val_frac = 0.2
val_size = int(len(df) * val_frac)
train_size = len(df) - val_size
g = torch.Generator().manual_seed(42)
indices = torch.randperm(len(df), generator=g).tolist()
train_idx, val_idx = indices[:train_size], indices[train_size:]

train_df = df.iloc[train_idx].reset_index(drop=True)
val_df   = df.iloc[val_idx].reset_index(drop=True)

train_ds = CloudDataset(train_df, transform=train_tfms)
val_ds   = CloudDataset(val_df,   transform=val_tfms)

batch_size = 64
use_cuda = torch.cuda.is_available()
train_dl = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True,
    num_workers=4, pin_memory=use_cuda, persistent_workers=True, prefetch_factor=2
)
val_dl = DataLoader(
    val_ds, batch_size=batch_size, shuffle=False,
    num_workers=4, pin_memory=use_cuda, persistent_workers=True, prefetch_factor=2
)

print(f"Train: {len(train_ds)}  |  Val: {len(val_ds)}  |  Batch: {batch_size}")


Train: 8000  |  Val: 2000  |  Batch: 64




In [None]:
device = torch.device("cuda" if use_cuda else "cpu")
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
in_feats = model.fc.in_features
model.fc = nn.Linear(in_feats, num_classes)
model = model.to(device)

criterion  = nn.BCEWithLogitsLoss()
optimizer  = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler  = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

torch.backends.cudnn.benchmark = True
scaler = GradScaler(enabled=use_cuda, device="cuda")

from sklearn.metrics import f1_score
@torch.no_grad()
def evaluate(model, dl, thresh=0.5):
    model.eval()
    total_loss, n_samples = 0.0, 0
    correct_per_label = torch.zeros(num_classes, dtype=torch.long)
    total_per_label   = torch.zeros(num_classes, dtype=torch.long)
    exact_correct = 0

    all_preds, all_labels = [], []

    for imgs, labels in dl:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        total_loss += loss.item() * imgs.size(0)
        n_samples  += imgs.size(0)

        preds = (logits.sigmoid() >= thresh).long()

        # accumulate for F1
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())

        correct_per_label += (preds == labels.long()).sum(dim=0).cpu()
        total_per_label   += torch.tensor([imgs.size(0)] * num_classes)
        exact_correct     += (preds.eq(labels.long()).all(dim=1)).sum().item()

    avg_loss = total_loss / max(1, n_samples)
    per_label_acc = (correct_per_label.float() / total_per_label.clamp(min=1).float()).tolist()
    exact_match = exact_correct / max(1, n_samples)

    # stack predictions + labels
    y_true = torch.cat(all_labels).numpy()
    y_pred = torch.cat(all_preds).numpy()

    f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)
    f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0)

    return avg_loss, per_label_acc, exact_match, f1_micro, f1_macro

epochs = 20
best_val = float('inf')

for epoch in range(1, epochs+1):
    model.train()
    pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{epochs}", leave=False)
    running_loss, n_seen = 0.0, 0

    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=use_cuda, dtype=torch.float16):
            logits = model(imgs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * imgs.size(0)
        n_seen += imgs.size(0)
        pbar.set_postfix(train_loss=f"{running_loss / max(1, n_seen):.4f}")

    val_loss, per_label_acc, exact_match, f1_micro, f1_macro = evaluate(model, val_dl, thresh=0.5)

    print(f"[Epoch {epoch}] val_loss={val_loss:.4f}  exact_match={exact_match:.4f} "
        f"f1_micro={f1_micro:.4f} f1_macro={f1_macro:.4f}")
    short = {lbl: f"{acc:.3f}" for lbl, acc in zip(CLOUD_COLS, per_label_acc)}
    print(" per-label acc:", short)

    if val_loss < best_val:
        best_val = val_loss
        torch.save({
            "model": model.state_dict(),
            "labels": CLOUD_COLS,
            "epoch": epoch,
            "val_loss": val_loss
        }, SAVE_PATH)

print(f"Saved final model to {SAVE_PATH}")

print("Done.")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 205MB/s]


Epoch 1/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 1] val_loss=0.3795  exact_match=0.2820 f1_micro=0.2461 f1_macro=0.1549
 per-label acc: {'altocumulus': '0.781', 'altostratus': '0.821', 'cirrocumulus': '0.847', 'cirrostratus': '0.823', 'cirrus': '0.809', 'cumulonimbus': '0.955', 'cumulus': '0.762', 'nimbostratus': '0.914', 'stratocumulus': '0.812', 'stratus': '0.899'}


Epoch 2/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 2] val_loss=0.3804  exact_match=0.2875 f1_micro=0.2849 f1_macro=0.1842
 per-label acc: {'altocumulus': '0.799', 'altostratus': '0.813', 'cirrocumulus': '0.848', 'cirrostratus': '0.817', 'cirrus': '0.809', 'cumulonimbus': '0.955', 'cumulus': '0.751', 'nimbostratus': '0.914', 'stratocumulus': '0.813', 'stratus': '0.899'}


Epoch 3/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 3] val_loss=0.3866  exact_match=0.2855 f1_micro=0.2699 f1_macro=0.1760
 per-label acc: {'altocumulus': '0.791', 'altostratus': '0.823', 'cirrocumulus': '0.836', 'cirrostratus': '0.818', 'cirrus': '0.804', 'cumulonimbus': '0.955', 'cumulus': '0.747', 'nimbostratus': '0.914', 'stratocumulus': '0.827', 'stratus': '0.899'}


Epoch 4/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 4] val_loss=0.4000  exact_match=0.2585 f1_micro=0.2918 f1_macro=0.1934
 per-label acc: {'altocumulus': '0.758', 'altostratus': '0.826', 'cirrocumulus': '0.837', 'cirrostratus': '0.821', 'cirrus': '0.789', 'cumulonimbus': '0.955', 'cumulus': '0.749', 'nimbostratus': '0.914', 'stratocumulus': '0.791', 'stratus': '0.900'}


Epoch 5/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 5] val_loss=0.4101  exact_match=0.2565 f1_micro=0.2960 f1_macro=0.2169
 per-label acc: {'altocumulus': '0.748', 'altostratus': '0.817', 'cirrocumulus': '0.824', 'cirrostratus': '0.815', 'cirrus': '0.791', 'cumulonimbus': '0.954', 'cumulus': '0.748', 'nimbostratus': '0.909', 'stratocumulus': '0.811', 'stratus': '0.896'}


Epoch 6/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 6] val_loss=0.4326  exact_match=0.2400 f1_micro=0.3026 f1_macro=0.2225
 per-label acc: {'altocumulus': '0.739', 'altostratus': '0.808', 'cirrocumulus': '0.835', 'cirrostratus': '0.801', 'cirrus': '0.791', 'cumulonimbus': '0.954', 'cumulus': '0.743', 'nimbostratus': '0.909', 'stratocumulus': '0.784', 'stratus': '0.887'}


Epoch 7/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 7] val_loss=0.4622  exact_match=0.2240 f1_micro=0.3200 f1_macro=0.2582
 per-label acc: {'altocumulus': '0.753', 'altostratus': '0.788', 'cirrocumulus': '0.822', 'cirrostratus': '0.760', 'cirrus': '0.775', 'cumulonimbus': '0.955', 'cumulus': '0.740', 'nimbostratus': '0.909', 'stratocumulus': '0.802', 'stratus': '0.889'}


Epoch 8/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 8] val_loss=0.4800  exact_match=0.2485 f1_micro=0.2857 f1_macro=0.2364
 per-label acc: {'altocumulus': '0.749', 'altostratus': '0.815', 'cirrocumulus': '0.805', 'cirrostratus': '0.812', 'cirrus': '0.790', 'cumulonimbus': '0.954', 'cumulus': '0.732', 'nimbostratus': '0.899', 'stratocumulus': '0.799', 'stratus': '0.890'}


Epoch 9/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 9] val_loss=0.5076  exact_match=0.2045 f1_micro=0.3357 f1_macro=0.2529
 per-label acc: {'altocumulus': '0.754', 'altostratus': '0.803', 'cirrocumulus': '0.821', 'cirrostratus': '0.794', 'cirrus': '0.751', 'cumulonimbus': '0.953', 'cumulus': '0.692', 'nimbostratus': '0.904', 'stratocumulus': '0.797', 'stratus': '0.886'}


Epoch 10/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 10] val_loss=0.5289  exact_match=0.2310 f1_micro=0.2931 f1_macro=0.2330
 per-label acc: {'altocumulus': '0.783', 'altostratus': '0.784', 'cirrocumulus': '0.823', 'cirrostratus': '0.781', 'cirrus': '0.768', 'cumulonimbus': '0.953', 'cumulus': '0.731', 'nimbostratus': '0.895', 'stratocumulus': '0.809', 'stratus': '0.897'}


Epoch 11/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 11] val_loss=0.5493  exact_match=0.2210 f1_micro=0.3298 f1_macro=0.2673
 per-label acc: {'altocumulus': '0.769', 'altostratus': '0.785', 'cirrocumulus': '0.811', 'cirrostratus': '0.777', 'cirrus': '0.769', 'cumulonimbus': '0.951', 'cumulus': '0.719', 'nimbostratus': '0.900', 'stratocumulus': '0.804', 'stratus': '0.880'}


Epoch 12/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 12] val_loss=0.5511  exact_match=0.2345 f1_micro=0.3160 f1_macro=0.2488
 per-label acc: {'altocumulus': '0.748', 'altostratus': '0.805', 'cirrocumulus': '0.793', 'cirrostratus': '0.809', 'cirrus': '0.772', 'cumulonimbus': '0.951', 'cumulus': '0.740', 'nimbostratus': '0.904', 'stratocumulus': '0.795', 'stratus': '0.881'}


Epoch 13/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 13] val_loss=0.5741  exact_match=0.2300 f1_micro=0.3119 f1_macro=0.2442
 per-label acc: {'altocumulus': '0.771', 'altostratus': '0.803', 'cirrocumulus': '0.811', 'cirrostratus': '0.800', 'cirrus': '0.777', 'cumulonimbus': '0.951', 'cumulus': '0.728', 'nimbostratus': '0.906', 'stratocumulus': '0.780', 'stratus': '0.888'}


Epoch 14/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 14] val_loss=0.6047  exact_match=0.2180 f1_micro=0.3376 f1_macro=0.2731
 per-label acc: {'altocumulus': '0.748', 'altostratus': '0.813', 'cirrocumulus': '0.794', 'cirrostratus': '0.793', 'cirrus': '0.767', 'cumulonimbus': '0.938', 'cumulus': '0.697', 'nimbostratus': '0.904', 'stratocumulus': '0.800', 'stratus': '0.893'}


Epoch 15/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 15] val_loss=0.6181  exact_match=0.2400 f1_micro=0.2991 f1_macro=0.2345
 per-label acc: {'altocumulus': '0.772', 'altostratus': '0.799', 'cirrocumulus': '0.832', 'cirrostratus': '0.772', 'cirrus': '0.782', 'cumulonimbus': '0.951', 'cumulus': '0.731', 'nimbostratus': '0.908', 'stratocumulus': '0.811', 'stratus': '0.887'}


Epoch 16/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 16] val_loss=0.6254  exact_match=0.2075 f1_micro=0.3357 f1_macro=0.2712
 per-label acc: {'altocumulus': '0.743', 'altostratus': '0.791', 'cirrocumulus': '0.815', 'cirrostratus': '0.767', 'cirrus': '0.758', 'cumulonimbus': '0.952', 'cumulus': '0.723', 'nimbostratus': '0.898', 'stratocumulus': '0.777', 'stratus': '0.891'}


Epoch 17/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 17] val_loss=0.6342  exact_match=0.2440 f1_micro=0.3202 f1_macro=0.2519
 per-label acc: {'altocumulus': '0.774', 'altostratus': '0.795', 'cirrocumulus': '0.828', 'cirrostratus': '0.808', 'cirrus': '0.798', 'cumulonimbus': '0.945', 'cumulus': '0.730', 'nimbostratus': '0.909', 'stratocumulus': '0.782', 'stratus': '0.895'}


Epoch 18/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 18] val_loss=0.6365  exact_match=0.2495 f1_micro=0.2946 f1_macro=0.2410
 per-label acc: {'altocumulus': '0.770', 'altostratus': '0.787', 'cirrocumulus': '0.815', 'cirrostratus': '0.805', 'cirrus': '0.789', 'cumulonimbus': '0.948', 'cumulus': '0.739', 'nimbostratus': '0.904', 'stratocumulus': '0.796', 'stratus': '0.891'}


Epoch 19/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 19] val_loss=0.6427  exact_match=0.2000 f1_micro=0.3444 f1_macro=0.2837
 per-label acc: {'altocumulus': '0.732', 'altostratus': '0.805', 'cirrocumulus': '0.772', 'cirrostratus': '0.761', 'cirrus': '0.745', 'cumulonimbus': '0.943', 'cumulus': '0.715', 'nimbostratus': '0.887', 'stratocumulus': '0.807', 'stratus': '0.883'}


Epoch 20/20:   0%|          | 0/125 [00:00<?, ?it/s]

[Epoch 20] val_loss=0.6482  exact_match=0.2295 f1_micro=0.3025 f1_macro=0.2400
 per-label acc: {'altocumulus': '0.758', 'altostratus': '0.806', 'cirrocumulus': '0.827', 'cirrostratus': '0.789', 'cirrus': '0.782', 'cumulonimbus': '0.950', 'cumulus': '0.726', 'nimbostratus': '0.899', 'stratocumulus': '0.794', 'stratus': '0.891'}
Saved final model to /content/drive/MyDrive/clouds-ml/models/resnet18_big10k.pth
Done.
