## Dataset preparation


In [None]:
# -------------------------------------------
# a) Torch imports
# -------------------------------------------
import os, glob, re, random, math, copy, time
import numpy as np
from PIL import Image
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

# -------------------------------------------
# b) label‑making rule  (adapted thresholds)
# -------------------------------------------
def to_label(ctr, impr):
    """Map click-through rate and impression count to a 3-way label.

    Returns 0 (low) when impressions < 100k AND CTR < 0.10; otherwise
    2 (high) when impressions >= 300k OR CTR >= 0.20; otherwise 1 (medium).
    """
    is_low = impr < 1e5 and ctr < 0.10
    if is_low:
        return 0
    is_high = impr >= 3e5 or ctr >= 0.20
    return 2 if is_high else 1

# -------------------------------------------
# c) ordinary image dataset
# -------------------------------------------
class AdImageDataset(Dataset):
    """Flat image dataset whose labels are parsed from the file names.

    Files are expected to be named ``<ctr>_<impressions>_<anything>.jpg``;
    CTR and impressions are mapped to a 3-way label by ``to_label``. Files
    whose name does not match, or whose numbers cannot be parsed, are skipped.

    Attributes:
        fpaths:  every ``*.jpg`` path found in ``root`` (usable or not).
        samples: list of ``(path, label)`` for usable files.
        labels:  labels of ``samples``, in the same order.
    """
    _pattern = re.compile(r'([\d.]+)_([\d]+)_.+\.jpg$', re.I)

    def __init__(self, root, transform=None):
        self.fpaths = glob.glob(os.path.join(root, '*.jpg'))
        self.transform = transform
        self.samples = []
        for p in self.fpaths:
            m = self._pattern.search(os.path.basename(p))
            if not m:
                continue
            try:
                # `[\d.]+` also matches strings such as "1.2.3" that float()
                # rejects; treat those like any other unparsable name.
                ctr = float(m.group(1))
                impr = float(m.group(2))
            except ValueError:
                continue
            self.samples.append((p, to_label(ctr, impr)))

        self.labels = [lbl for _, lbl in self.samples]

    def indices_to_labels(self, idx):
        """Return the labels of the samples at the given indices."""
        # A regular method (instead of a lambda stored on the instance) keeps
        # the dataset picklable, e.g. for DataLoader workers under 'spawn'.
        return [self.labels[i] for i in idx]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # Context manager closes the underlying file handle promptly.
        with Image.open(path) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

# Common evaluation transform: fixed 224x224 resize, conversion to a tensor,
# then per-channel normalisation with the ImageNet mean/std.
img_tf = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [None]:

!pip install --no-deps torchmeta==1.8.0



Collecting torchmeta==1.8.0
  Downloading torchmeta-1.8.0-py3-none-any.whl.metadata (8.2 kB)
Downloading torchmeta-1.8.0-py3-none-any.whl (210 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/210.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.4/210.4 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchmeta
Successfully installed torchmeta-1.8.0


In [None]:
!pip install ordered-set

Collecting ordered-set
  Downloading ordered_set-4.1.0-py3-none-any.whl.metadata (5.3 kB)
Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)
Installing collected packages: ordered-set
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchmeta 1.8.0 requires torch<1.10.0,>=1.4.0, but you have torch 2.6.0+cu124 which is incompatible.
torchmeta 1.8.0 requires torchvision<0.11.0,>=0.5.0, but you have torchvision 0.21.0+cu124 which is incompatible.[0m[31m
[0mSuccessfully installed ordered-set-4.1.0


In [None]:
import torchvision.datasets.utils as tv_utils

# torchmeta 1.x expects these, but newer torchvision has removed them
def _get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None

def _save_response_content(response, destination, chunk_size=32768):
    with open(destination, "wb") as f:
        for chunk in response.iter_content(chunk_size):
            if chunk:
                f.write(chunk)

# Install the shims only where torchvision lacks the attribute. Overwriting
# unconditionally would also replace a same-named helper that torchvision may
# still define and call itself (possibly with a different signature), breaking
# torchvision's own download utilities for the rest of the session.
for _name, _shim in (("_get_confirm_token", _get_confirm_token),
                     ("_save_response_content", _save_response_content)):
    if not hasattr(tv_utils, _name):
        setattr(tv_utils, _name, _shim)


In [None]:
# List every directory up to two levels below /content (presumably to check
# where the image data ended up).
!find /content -maxdepth 2 -type d

/content
/content/.config
/content/.config/logs
/content/.config/configurations
/content/sample_data
/content/sample_data/.ipynb_checkpoints


## Prototypical Network (episodic training, TorchMeta)


In [None]:
from torchmeta.utils.data.task import Dataset as TMTaskDataset

def __getitem__(self, class_idx):
    """Free-standing draft of a ClassDataset ``__getitem__`` (not bound to a class).

    Builds a torchmeta task dataset over every image path of the class at
    position ``class_idx`` in ``self.classes``; each item is
    ``(self.transform(image), class_idx)``. Expects ``self`` to provide
    ``classes``, ``class_to_imgs`` and ``transform``.
    NOTE(review): nothing calls this module-level function by name; the
    version actually used lives inside ``AdClassDataset``.
    """
    class_paths = self.class_to_imgs[self.classes[class_idx]]

    class OneClassDataset(TMTaskDataset):
        def __init__(self, paths, transform, index):
            # The torchmeta task Dataset constructor stores index/transform/
            # target_transform on the instance.
            super().__init__(index, transform=transform, target_transform=None)
            self.paths = paths

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, i):
            rgb = Image.open(self.paths[i]).convert('RGB')
            return self.transform(rgb), class_idx

    return OneClassDataset(class_paths, self.transform, class_idx)


In [None]:
# ─── Imports & setup ────────────────────────────────────────
import random
import torch
import torch.nn.functional as F
from torch import nn
from torchmeta.utils.data import ClassDataset, CombinationMetaDataset, BatchMetaDataLoader
from torchmeta.transforms import ClassSplitter
from torchmeta.utils.data.task import Dataset as TMTaskDataset
from torch.utils.data import Dataset
from torchvision import models
from PIL import Image
from collections import defaultdict
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Episode geometry: N_way classes per task, K_shot support and Q_query query
# images per class.
N_way = 3
K_shot = 5
Q_query = 10
num_epochs = 20
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ─── AdClassDataset subclass ───────────────────────────────────────────
class AdClassDataset(ClassDataset):
    """torchmeta ``ClassDataset`` with one class per label of ``AdImageDataset``.

    ``__getitem__(class_idx)`` returns a task dataset over every image of the
    class at position ``class_idx`` in ``self.classes``; each item is
    ``(transformed image, class_idx)``.

    Args:
        root: folder of ``<ctr>_<impressions>_*.jpg`` images.
        transform: applied to each PIL image when it is loaded.
        meta_train: forwarded to ``ClassDataset``.
    """
    def __init__(self, root, transform, meta_train=True):
        super().__init__(meta_train=meta_train)
        self.transform = transform
        base = AdImageDataset(root, transform)
        groups = defaultdict(list)
        for p, l in base.samples:
            groups[l].append(p)
        # Sort so that class position == label value (0=low, 1=medium, 2=high).
        # Insertion order follows glob() order, which is arbitrary. Items are
        # targeted by *position*, and the training loop computes MAE/MSE/R² on
        # those targets, so an arbitrary order would scramble their ordinal meaning.
        self.classes = sorted(groups)
        self.class_to_imgs = groups

    @property
    def num_classes(self):
        return len(self.classes)

    def __len__(self):
        return self.num_classes

    def __getitem__(self, class_idx):
        paths = self.class_to_imgs[self.classes[class_idx]]

        class OneClassDataset(TMTaskDataset):
            def __init__(self, paths, transform, index):
                super().__init__(index, transform=transform, target_transform=None)
                self.paths = paths

            def __len__(self):
                return len(self.paths)

            def __getitem__(self, i):
                # Context manager closes the file handle promptly.
                with Image.open(self.paths[i]) as im:
                    img = im.convert('RGB')
                return self.transform(img), class_idx

        return OneClassDataset(paths, self.transform, class_idx)

# ─── Build meta‐dataset & loader ────────────────────────────────────────────
# Each task: N_way classes; per class, K_shot shuffled images go to
# batch['train'] (support) and Q_query to batch['test'] (query).
dataset_transform = ClassSplitter(
    num_train_per_class=K_shot,
    num_test_per_class=Q_query,
    shuffle=True,
)
meta_ds = CombinationMetaDataset(
    AdClassDataset('/content/sample_data', img_tf),
    num_classes_per_task=N_way,
    dataset_transform=dataset_transform,
)
loader = BatchMetaDataLoader(
    meta_ds,
    batch_size=1,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# ─── ProtoNet & optimizer ──────────────────────────────────────────────────
class ProtoNet(nn.Module):
    """Embedding network for prototypical networks.

    A ResNet-18 loaded with ``ResNet18_Weights.DEFAULT`` whose final fc layer
    is replaced by identity, so ``forward`` returns the pooled backbone
    features. ``emb_dim`` is accepted but not used.
    """
    def __init__(self, emb_dim=512):
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.backbone.fc = nn.Identity()   # drop the classification head

    def forward(self, x):
        features = self.backbone(x)
        return features


model = ProtoNet().to(device)
optim = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# ─── Training loop with best‐model tracking ─────────────────────────────────
best_r2 = float('-inf')
best_metrics = {}
best_state = None

for epoch in range(1, num_epochs+1):
    model.train()
    running_loss = 0.0
    all_trues, all_preds = [], []

    # NOTE: metrics are computed on the query sets of this epoch's *training*
    # episodes (there is no held-out split), so "best" means best on training
    # episodes, not best generalisation.
    for batch in loader:
        # unpack and drop the leading task dimension (batch_size=1)
        (support_x, support_y) = batch['train']
        ( query_x,   query_y) = batch['test']
        support_x = support_x.squeeze(0).to(device)
        support_y = support_y.squeeze(0).to(device)
        query_x   = query_x.squeeze(0).to(device)
        query_y   = query_y.squeeze(0).to(device)

        # prototypes = mean support embedding per class; logits = negative
        # squared Euclidean distance from each query to each prototype
        emb_sup = model(support_x)
        emb_qry = model(query_x)
        protos = torch.stack([emb_sup[support_y==c].mean(0)
                               for c in range(N_way)])
        dists  = ((emb_qry.unsqueeze(1)-protos)**2).sum(-1)
        logits = -dists
        loss   = F.cross_entropy(logits, query_y)

        optim.zero_grad()
        loss.backward()
        optim.step()

        running_loss += loss.item()
        preds = logits.argmax(1).cpu().tolist()
        trues = query_y.cpu().tolist()
        all_preds.extend(preds)
        all_trues.extend(trues)

    # compute metrics this epoch
    mae  = mean_absolute_error(all_trues, all_preds)
    mse  = mean_squared_error(all_trues, all_preds)
    rmse = mse**0.5
    r2   = r2_score(all_trues, all_preds)
    avg_loss = running_loss / len(loader)

    print(f"Epoch {epoch:02d} — loss {avg_loss:.4f}  "
          f"MAE {mae:.4f}  MSE {mse:.4f}  RMSE {rmse:.4f}  R² {r2:.4f}")

    # track best by highest R²
    if r2 > best_r2:
        best_r2 = r2
        best_metrics = dict(MAE=mae, MSE=mse, RMSE=rmse, R2=r2, loss=avg_loss)
        # state_dict() returns references to the live tensors. Without a copy,
        # later optimizer steps would overwrite this "best" snapshot, and
        # load_state_dict() below would merely restore the final weights.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

if best_state is None:
    raise RuntimeError("No epoch produced a comparable R² (e.g. all NaN); "
                       "there is no best model to report or restore.")

# ─── Final-best metrics ─────────────────────────────────────────────────────
print("\n>>> Best model metrics:")
print(f" MAE:  {best_metrics['MAE']:.4f}")
print(f" MSE:  {best_metrics['MSE']:.4f}")
print(f" RMSE: {best_metrics['RMSE']:.4f}")
print(f" R²:   {best_metrics['R2']:.4f}")
print(f" loss: {best_metrics['loss']:.4f}")

# load best weights back into the model
model.load_state_dict(best_state)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 116MB/s]


Epoch 01 — loss 14.6464  MAE 0.5333  MSE 0.8667  RMSE 0.9309  R² -0.3000
Epoch 02 — loss 21.1344  MAE 0.7333  MSE 1.2000  RMSE 1.0954  R² -0.8000
Epoch 03 — loss 2.6205  MAE 0.1000  MSE 0.1000  RMSE 0.3162  R² 0.8500
Epoch 04 — loss 8.5375  MAE 0.5000  MSE 0.8333  RMSE 0.9129  R² -0.2500
Epoch 05 — loss 9.8611  MAE 0.5000  MSE 0.7667  RMSE 0.8756  R² -0.1500
Epoch 06 — loss 0.9096  MAE 0.1667  MSE 0.2333  RMSE 0.4830  R² 0.6500
Epoch 07 — loss 6.8392  MAE 0.3333  MSE 0.4667  RMSE 0.6831  R² 0.3000
Epoch 08 — loss 4.9625  MAE 0.0333  MSE 0.0333  RMSE 0.1826  R² 0.9500
Epoch 09 — loss 1.6025  MAE 0.1333  MSE 0.2000  RMSE 0.4472  R² 0.7000
Epoch 10 — loss 6.1206  MAE 0.2333  MSE 0.3000  RMSE 0.5477  R² 0.5500
Epoch 11 — loss 2.1177  MAE 0.1000  MSE 0.1667  RMSE 0.4082  R² 0.7500
Epoch 12 — loss 5.2206  MAE 0.2333  MSE 0.4333  RMSE 0.6583  R² 0.3500
Epoch 13 — loss 0.0004  MAE 0.0000  MSE 0.0000  RMSE 0.0000  R² 1.0000
Epoch 14 — loss 0.0004  MAE 0.0000  MSE 0.0000  RMSE 0.0000  R² 1.0000


<All keys matched successfully>

In [None]:

dataset = AdClassDataset('/content/sample_data', img_tf)

# 1) Number of classes
print(f"Total classes: {dataset.num_classes}")

# 2) Samples per class, in dataset.classes order
per_class = {cls: len(dataset.class_to_imgs[cls]) for cls in dataset.classes}
for cls, n in per_class.items():
    print(f"  Class {cls!r} has {n} samples")
total = sum(per_class.values())

print(f"Total samples across all classes: {total}")


Total classes: 3
  Class 2 has 487 samples
  Class 1 has 142 samples
  Class 0 has 28 samples
Total samples across all classes: 657


In [None]:
# Mount Google Drive at /content/drive.
# NOTE(review): no path in this notebook points under /content/drive (data is
# read from /content/sample_data or a local Windows folder), so this mount
# may be unused.
from google.colab import drive
drive.mount('/content/drive')

## MAML (Model‑Agnostic Meta‑Learning — custom first‑order implementation; learn2learn is installed but not used)

In [None]:
!pip install --upgrade pip setuptools wheel cython
!git clone https://github.com/learnables/learn2learn.git
%cd learn2learn
!pip install -e .
%cd ..


fatal: destination path 'learn2learn' already exists and is not an empty directory.
/content/learn2learn
Obtaining file:///content/learn2learn
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: learn2learn
  Attempting uninstall: learn2learn
    Found existing installation: learn2learn 0.2.1
    Uninstalling learn2learn-0.2.1:
      Successfully uninstalled learn2learn-0.2.1
[33m  DEPRECATION: Legacy editable install of learn2learn==0.2.1 from file:///content/learn2learn (setup.py develop) is deprecated. pip 25.3 will enforce this behaviour change. A possible replacement is to add a pyproject.toml or enable --use-pep517, and use setuptools >= 64. If the resulting installation is not behaving as expected, try using --config-settings editable_mode=compat. Please consult the setuptools documentation for more information. Discussion can be found at https://github.com/pypa/pip/issues/11457[0m[33m
[0m  Running setup.py develop for learn2learn
Successfully

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from collections import defaultdict
import random
import copy

# Custom MAML implementation
class MAML:
    """Minimal MAML wrapper around ``model``.

    ``clone()`` returns a :class:`MAMLLearner` that runs ``model`` with a
    functional set of parameters. Inner-loop updates made by
    ``MAMLLearner.adapt`` are autograd expressions of ``model``'s own
    parameters, so a loss computed with the adapted learner back-propagates
    into ``self.parameters()``, which is what the outer optimizer steps.

    Args:
        model: the ``nn.Module`` to meta-learn.
        lr: inner-loop (adaptation) SGD step size.
        first_order: if True, inner-loop gradients are treated as constants
            (first-order MAML); if False, the second-order graph is kept.
    """

    def __init__(self, model, lr=1e-3, first_order=True):
        self.model = model
        self.lr = lr
        self.first_order = first_order

    def clone(self):
        # No deepcopy: a deep-copied module has its own leaf parameters, so the
        # query loss would never reach self.model and the outer optimizer
        # would step with empty (None) gradients.
        return MAMLLearner(self.model, self.lr, self.first_order)

    def parameters(self):
        return self.model.parameters()


class MAMLLearner:
    """Task-specific view of ``model`` whose weights are adapted functionally.

    ``params`` maps parameter names to tensors. It starts as the model's own
    parameters. ``adapt`` replaces entries with ``p - lr * grad`` without
    touching the module, and ``__call__`` evaluates the module with these
    tensors via ``torch.func.functional_call``. Buffers (e.g. BatchNorm
    statistics) are not in ``params`` and are therefore taken from the module.
    """

    def __init__(self, model, lr, first_order, params=None):
        self.model = model
        self.lr = lr
        self.first_order = first_order
        self.params = (dict(model.named_parameters()) if params is None
                       else dict(params))

    def __call__(self, x):
        from torch.func import functional_call
        return functional_call(self.model, self.params, (x,))

    def adapt(self, loss):
        """Take one SGD step on ``loss`` w.r.t. the current adapted parameters."""
        # Only differentiate tensors that can carry gradients (frozen
        # parameters would make autograd.grad raise).
        names = [n for n, p in self.params.items() if p.requires_grad]
        grads = torch.autograd.grad(
            loss,
            [self.params[n] for n in names],
            create_graph=not self.first_order,
            retain_graph=not self.first_order,
            allow_unused=True,   # parameters not used by `loss` get None
        )
        updated = dict(self.params)
        for name, grad in zip(names, grads):
            if grad is not None:
                # Out-of-place update keeps the graph back to the original
                # parameters (unlike assigning to .data).
                updated[name] = self.params[name] - self.lr * grad
        self.params = updated

# Manual task creation approach
class ManualTaskSampler:
    """Draw N-way, K-shot classification episodes from a flat dataset.

    Args:
        dataset: sized, indexable dataset whose items are ``(data, label)``
            with ``data`` a tensor. If it also exposes a ``labels`` sequence of
            the same length (as ``AdImageDataset`` does), that sequence is used
            to group indices, so construction does not load every sample.
        n_ways: classes per episode.
        k_shots: support examples per class.
        query_shots: query examples per class.

    Raises:
        ValueError: if the dataset has fewer than ``n_ways`` classes.
    """

    def __init__(self, dataset, n_ways=3, k_shots=5, query_shots=10):
        self.dataset = dataset
        self.n_ways = n_ways
        self.k_shots = k_shots
        self.query_shots = query_shots

        labels = getattr(dataset, 'labels', None)
        if labels is None or len(labels) != len(dataset):
            # Fallback: read labels by indexing (loads/transforms each sample).
            labels = [dataset[idx][1] for idx in range(len(dataset))]

        # Group sample indices by class
        self.class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_indices[label].append(idx)

        self.classes = list(self.class_indices.keys())
        print(f"Found {len(self.classes)} classes with samples: {[len(v) for v in self.class_indices.values()]}")

        # Check if we have enough classes
        if len(self.classes) < self.n_ways:
            raise ValueError(f"Dataset has only {len(self.classes)} classes, but n_ways={self.n_ways}")

    def sample_task(self):
        """Sample one episode.

        Returns:
            ``(data, labels)``: the ``n_ways * k_shots`` support items followed
            by the ``n_ways * query_shots`` query items. Labels are remapped to
            ``0..n_ways-1`` in the order the classes were drawn.
        """
        selected_classes = random.sample(self.classes, self.n_ways)
        n_samples_needed = self.k_shots + self.query_shots

        support_data, support_labels = [], []
        query_data, query_labels = [], []

        for new_label, original_class in enumerate(selected_classes):
            class_indices = self.class_indices[original_class]
            if len(class_indices) < n_samples_needed:
                # Too few samples: draw with replacement. NOTE: support and
                # query may then contain the same image.
                sampled_indices = random.choices(class_indices, k=n_samples_needed)
            else:
                sampled_indices = random.sample(class_indices, n_samples_needed)

            for idx in sampled_indices[:self.k_shots]:
                support_data.append(self.dataset[idx][0])
                support_labels.append(new_label)
            for idx in sampled_indices[self.k_shots:n_samples_needed]:
                query_data.append(self.dataset[idx][0])
                query_labels.append(new_label)

        all_data = torch.cat([torch.stack(support_data), torch.stack(query_data)], dim=0)
        all_labels = torch.tensor(support_labels + query_labels)
        return all_data, all_labels

# Usage with the dataset
# Flat (image, label) dataset and the 3-way / 5-shot / 10-query sampler over it
base_ds = AdImageDataset(root='/content/sample_data', transform=img_tf)
task_sampler = ManualTaskSampler(dataset=base_ds, n_ways=3, k_shots=5,
                                 query_shots=10)

# Define model
class ConvClassifier(nn.Module):
    """ResNet-18 built with ``weights=None`` (random init) whose fc layer is an
    ``n_out``-way linear head; ``forward`` returns raw logits."""
    def __init__(self, n_out=3):
        super().__init__()
        net = models.resnet18(weights=None)
        net.fc = nn.Linear(512, n_out)
        self.backbone = net

    def forward(self, x):
        logits = self.backbone(x)
        return logits

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = ConvClassifier().to(device)

# MAML wrapper: inner-loop step 1e-3 (first-order); outer Adam, also lr 1e-3
maml = MAML(model, lr=1e-3, first_order=True)
opt = torch.optim.Adam(params=maml.parameters(), lr=1e-3)

print("Starting meta-training with custom MAML implementation...")

# sample_task() returns the n_ways*k_shots support items first, then the
# queries. Derive the split from the sampler instead of hard-coding 3*5.
n_support = task_sampler.n_ways * task_sampler.k_shots

# Training loop: each iteration ("epoch") samples one task and performs one
# meta-update (adapt on support, optimizer step on the query loss).
for epoch in range(1, 1001):
    try:
        # Sample a task manually
        data, labels = task_sampler.sample_task()
        data, labels = data.to(device), labels.to(device)

        # Split into support and query
        support_x, support_y = data[:n_support], labels[:n_support]
        query_x, query_y = data[n_support:], labels[n_support:]

        # MAML adaptation
        learner = maml.clone()
        support_logits = learner(support_x)
        loss_s = F.cross_entropy(support_logits, support_y)
        learner.adapt(loss_s)

        # Query loss
        query_logits = learner(query_x)
        loss_q = F.cross_entropy(query_logits, query_y)

        # Meta-update
        opt.zero_grad()
        loss_q.backward()
        opt.step()

        if epoch % 100 == 0:
            with torch.no_grad():
                preds = query_logits.argmax(dim=1).cpu().numpy()
                trues = query_y.cpu().numpy()

                acc = (preds == trues).mean()
                # NOTE: MAE/MSE/R² treat the remapped episode labels as numbers
                mae = mean_absolute_error(trues, preds)
                mse = mean_squared_error(trues, preds)
                rmse = np.sqrt(mse)
                r2 = r2_score(trues, preds)

                print(f"Epoch {epoch:4d}  Loss {loss_q.item():.4f}  "
                      f"Acc {acc:.3f}  MAE {mae:.4f}  MSE {mse:.4f}  "
                      f"RMSE {rmse:.4f}  R² {r2:.4f}")

    except Exception as e:
        print(f"Error at epoch {epoch}: {e}")
        import traceback
        traceback.print_exc()
        break

print("Training completed!")

# Function to evaluate on new tasks
def evaluate_few_shot(maml, task_sampler, n_test_tasks=50):
    """Evaluate the meta-learned model on freshly sampled tasks.

    For each task the model is cloned, adapted once on the support set and
    scored on the query set.

    Args:
        maml: wrapper exposing ``clone()`` (a learner with ``__call__`` and
            ``adapt``) and ``parameters()``.
        task_sampler: object whose ``sample_task()`` returns support items
            followed by query items; its ``n_ways`` / ``k_shots`` attributes
            (default 3 / 5 if absent) give the support-set size.
        n_test_tasks: number of tasks to sample.

    Returns:
        ``(mean_acc, std_acc)`` of query accuracy over the tasks that ran, or
        ``(0.0, 0.0)`` if every task failed.
    """
    test_accuracies = []
    n_support = getattr(task_sampler, 'n_ways', 3) * getattr(task_sampler, 'k_shots', 5)
    # Put each task on the device the meta-learned parameters live on.
    eval_device = next(iter(maml.parameters())).device

    print(f"Evaluating on {n_test_tasks} test tasks...")

    for i in range(n_test_tasks):
        try:
            # Sample a test task
            data, labels = task_sampler.sample_task()
            data, labels = data.to(eval_device), labels.to(eval_device)

            support_x, support_y = data[:n_support], labels[:n_support]
            query_x, query_y = data[n_support:], labels[n_support:]

            # Clone and adapt (needs gradients, so outside no_grad)
            learner = maml.clone()
            support_logits = learner(support_x)
            loss_s = F.cross_entropy(support_logits, support_y)
            learner.adapt(loss_s)

            # Test on query set
            with torch.no_grad():
                query_logits = learner(query_x)
                preds = query_logits.argmax(dim=1)
                acc = (preds == query_y).float().mean().item()
                test_accuracies.append(acc)

        except Exception as e:
            # Best-effort: report and skip a failing task instead of aborting.
            print(f"Error in test task {i}: {e}")
            continue

    if test_accuracies:
        mean_acc = np.mean(test_accuracies)
        std_acc = np.std(test_accuracies)
        print(f"Test Performance: {mean_acc:.3f} ± {std_acc:.3f}")
        return mean_acc, std_acc
    else:
        print("No successful test tasks!")
        return 0.0, 0.0

# evaluate_few_shot(maml, task_sampler)

Found 3 classes with samples: [487, 142, 28]
Starting meta-training with custom MAML implementation...
Epoch  100  Loss 1.0650  Acc 0.433  MAE 0.6333  MSE 0.7667  RMSE 0.8756  R² -0.1500
Epoch  200  Loss 1.1031  Acc 0.300  MAE 0.7333  MSE 0.8000  RMSE 0.8944  R² -0.2000
Epoch  300  Loss 1.1237  Acc 0.367  MAE 0.7000  MSE 0.8333  RMSE 0.9129  R² -0.2500
Epoch  400  Loss 1.1152  Acc 0.367  MAE 0.7333  MSE 0.9333  RMSE 0.9661  R² -0.4000
Epoch  500  Loss 1.1002  Acc 0.300  MAE 0.8667  MSE 1.2000  RMSE 1.0954  R² -0.8000
Epoch  600  Loss 1.1109  Acc 0.300  MAE 0.8333  MSE 1.1000  RMSE 1.0488  R² -0.6500
Epoch  700  Loss 1.1371  Acc 0.367  MAE 0.6667  MSE 0.7333  RMSE 0.8563  R² -0.1000
Epoch  800  Loss 1.1493  Acc 0.267  MAE 0.8000  MSE 0.9333  RMSE 0.9661  R² -0.4000
Epoch  900  Loss 1.1054  Acc 0.500  MAE 0.5000  MSE 0.5000  RMSE 0.7071  R² 0.2500
Epoch 1000  Loss 1.0886  Acc 0.300  MAE 0.8333  MSE 1.1000  RMSE 1.0488  R² -0.6500
Training completed!


## Few‑Shot Linear Probe (logistic‑regression head on frozen ViT‑B/16 embeddings) — not used in the final pipeline


In [None]:
from torchvision.models import vit_b_16, ViT_B_16_Weights
weights = ViT_B_16_Weights.DEFAULT
backbone = vit_b_16(weights=weights)
backbone.heads = nn.Identity()       # remove classifier
for p in backbone.parameters():      # freeze
    p.requires_grad = False
# Inputs are sent to `device` below, so the backbone must live there too
# (otherwise a device-mismatch error on GPU); eval() = inference mode.
backbone = backbone.to(device).eval()

# TODO: placeholder — build a real 15-image subset of AdImageDataset here
# (e.g. with torch.utils.data.Subset); the cell cannot run until this is filled in.
few_ds = ...
few_loader = DataLoader(few_ds, batch_size=15)

# extract embeddings (one batch); no autograd graph is needed
imgs, lbs = next(iter(few_loader))
with torch.no_grad():
    embs = backbone(imgs.to(device)).cpu()

# train logistic regression on these 15 vectors
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(max_iter=1000).fit(embs.numpy(), lbs.numpy())

# inference
# NOTE(review): hard-coded local Windows path; it does not exist on Colab.
val_ds  = AdImageDataset(r'C:\Bakalauras\downloaded_images\val', img_tf)
val_emb = []
val_lab = []
with torch.no_grad():
    for img, l in DataLoader(val_ds, batch_size=32):
        val_emb.append(backbone(img.to(device)).cpu())
        val_lab.append(l)
val_emb = torch.cat(val_emb).numpy()
val_lab = torch.cat(val_lab).numpy()
pred = clf.predict(val_emb)

from sklearn.metrics import classification_report
print(classification_report(val_lab, pred, digits=4))


In [None]:
# ==========================================================
# 1. Imports (torch 1.8.1, torchvision 0.9.1, torchmeta 1.7.0)
# ==========================================================
import os, glob, re, random
from PIL import Image
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset
from torchmeta.utils.data import (
    ClassDataset, CombinationMetaDataset, BatchMetaDataLoader,
    Dataset as MetaDataset,
)

# ----------------------------------------------------------
# 2. Simple (ctr, impressions) → class label
# ----------------------------------------------------------
def to_label(ctr, impr):
    """Map (CTR, impressions) to 0 (low), 1 (medium) or 2 (high).

    Low: impressions < 1e5 and CTR < 0.10. High: impressions >= 3e5 or
    CTR >= 0.20. Everything else is medium. Redefining it here shadows any
    earlier definition in the kernel.
    """
    if impr < 1e5 and ctr < 0.10:
        label = 0          # low
    elif impr >= 3e5 or ctr >= 0.20:
        label = 2          # high
    else:
        label = 1          # medium
    return label

# ----------------------------------------------------------
# 3. Plain banner-image dataset
# ----------------------------------------------------------
class AdImageDataset(Dataset):
    """Banner images labelled from ``<ctr>_<impressions>_*.jpg`` file names.

    Non-matching files are skipped. ``labels`` and its alias ``targets``
    (the same list object) hold the per-sample labels.
    """
    _pat = re.compile(r'([\d.]+)_([\d]+)_.+\.jpg$', re.I)

    def __init__(self, root, transform):
        self.transform = transform
        matches = (
            (p, self._pat.search(os.path.basename(p)))
            for p in glob.glob(os.path.join(root, '*.jpg'))
        )
        # (path, lbl) for every file whose name matches the pattern
        self.samples = [
            (p, to_label(float(m.group(1)), float(m.group(2))))
            for p, m in matches
            if m
        ]
        self.labels = [lbl for _, lbl in self.samples]
        self.targets = self.labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, lbl = self.samples[idx]
        image = Image.open(path).convert('RGB')
        return self.transform(image), lbl

# ----------------------------------------------------------
# 4. Per-class dataset that torchmeta 1.7.0 can use
# ----------------------------------------------------------

from torchmeta.utils.data import Dataset as MetaDataset

class ClassImagesDataset(MetaDataset):
    """All images of ONE class; every item is ``(transformed image, cls_idx)``."""
    def __init__(self, paths, transform, cls_idx):
        # NOTE(review): the torchmeta task Dataset used in the Colab cells is
        # constructed as (index, transform=..., target_transform=...); confirm
        # that the installed torchmeta's `Dataset` also accepts `meta_split`.
        super().__init__(cls_idx, transform=transform, meta_split='train')
        self.paths = paths          # image file paths of this class
        self.index = cls_idx        # returned as the label of every item

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert('RGB')
        tensor = self.transform(image)
        return tensor, self.index


# ----------------------------------------------------------
# 5. ClassDataset wrapper
# ----------------------------------------------------------
class AdClassDataset(ClassDataset):
    """torchmeta ClassDataset over the three labels 0, 1, 2 of AdImageDataset.

    ``__getitem__(idx)`` returns a ``ClassImagesDataset`` with every image
    path of class ``self.classes[idx]``.
    """
    def __init__(self, root, transform):
        super().__init__(meta_train=True)   # meta-training split

        base = AdImageDataset(root, transform)
        self.cls_paths = {label: [] for label in (0, 1, 2)}
        for path, label in base.samples:
            self.cls_paths[label].append(path)

        self.classes = list(self.cls_paths)
        self.transform = transform

    def __len__(self):
        return len(self.classes)

    def __getitem__(self, idx):
        label = self.classes[idx]
        return ClassImagesDataset(self.cls_paths[label], self.transform, label)


# ----------------------------------------------------------
# 6. Image transforms
# ----------------------------------------------------------
# 224x224 resize -> tensor -> ImageNet mean/std normalisation
_imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])
img_tf = transforms.Compose([transforms.Resize((224, 224)),
                             transforms.ToTensor(),
                             _imagenet_normalize])

# ==========================================================
# 7. Meta-dataset & loader (num_workers=0 on Windows)
# ==========================================================
# Local (Windows) data folder and episode geometry
root_dir = r'C:\Bakalauras\downloaded_images'
N_way = 3
K_shot = 5
Q_query = 10

cls_dataset = AdClassDataset(root_dir, img_tf)
meta_ds = CombinationMetaDataset(cls_dataset, num_classes_per_task=N_way)

# single-process loading (num_workers=0) for Windows
loader = BatchMetaDataLoader(meta_ds, batch_size=1, shuffle=True, num_workers=0)

# ==========================================================
# 8. ProtoNet backbone (torchvision 0.9.1)
# ==========================================================
class ProtoNet(nn.Module):
    """ResNet-18 embedding net loaded via the legacy ``pretrained=True`` API;
    the fc layer is replaced by identity, so ``forward`` returns 512-d features."""
    def __init__(self):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        resnet.fc = nn.Identity()
        self.backbone = resnet

    def forward(self, x):
        return self.backbone(x)


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model  = ProtoNet().to(device)
optim  = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# ==========================================================
# 9. Episodic training loop
# ==========================================================
def load_stack(ds, idxs):
    """Fetch ``ds[i]`` for each index and stack the first element (the input)
    of every item into one tensor; labels are discarded."""
    samples = (ds[i] for i in idxs)
    return torch.stack([sample[0] for sample in samples])

import itertools

# BatchMetaDataLoader yields collated tensors (compare the Colab cell above,
# which unpacks its batches as tensors), not ConcatTask objects, so
# `episode.datasets` is not available on what `loader` yields. Build episodes
# directly from the per-class datasets instead: each epoch visits every
# N_way-class combination once, in shuffled order.
class_combos = list(itertools.combinations(range(len(cls_dataset)), N_way))

for epoch in range(5):
    random.shuffle(class_combos)
    last_loss = None
    for combo in class_combos:
        class_datasets = [cls_dataset[c] for c in combo]
        support_x, query_x = [], []
        support_y, query_y = [], []

        # build support/query sets; skip the episode if a class is too small
        ok = True
        for cls_idx, class_ds in enumerate(class_datasets):
            if len(class_ds) < K_shot + Q_query:
                ok = False
                break
            perm = torch.randperm(len(class_ds)).tolist()
            sup_idx = perm[:K_shot]
            qry_idx = perm[K_shot:K_shot + Q_query]

            support_x.append(load_stack(class_ds, sup_idx))
            query_x.append(load_stack(class_ds, qry_idx))
            support_y += [cls_idx] * K_shot
            query_y   += [cls_idx] * Q_query

        if not ok:
            continue

        support_x = torch.cat(support_x).to(device)
        query_x   = torch.cat(query_x).to(device)
        support_y = torch.tensor(support_y, device=device)
        query_y   = torch.tensor(query_y,   device=device)

        # embeddings
        emb_sup = model(support_x)
        emb_qry = model(query_x)

        # prototypes = mean support embedding per episode class
        protos = torch.stack([emb_sup[support_y == c].mean(0)
                              for c in range(N_way)])

        # negative squared Euclidean distance as logits
        logits = -((emb_qry.unsqueeze(1) - protos) ** 2).sum(-1)
        loss   = F.cross_entropy(logits, query_y)

        optim.zero_grad()
        loss.backward()
        optim.step()
        last_loss = loss.item()

    # `loss` would be undefined if every episode was skipped; report that case.
    if last_loss is None:
        print(f'Epoch {epoch+1}: no usable episode '
              f'(some class has fewer than {K_shot + Q_query} images)')
    else:
        print(f'Epoch {epoch+1}: episode loss {last_loss:.4f}')


### Augmentations Experiment


In [None]:
import os, glob, re
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import transforms, models
from PIL import Image

from sklearn.metrics import classification_report, mean_absolute_error, mean_squared_error, r2_score

# 1) Dataset for 3-way classification (viral categories)
class ViralAdsDataset(Dataset):
    """Banner images with a 3-way 'virality' label parsed from the file name.

    Names must match ``<ctr>_<impressions>_<anything>.jpg`` (case-sensitive,
    anchored at the start); other files are skipped. Labels: 0 = no
    (impr < 1e5 and ctr < 0.1), 2 = viral (impr >= 3e5 or ctr >= 0.2),
    1 = moderate otherwise.
    """
    def __init__(self, folder, transform=None):
        self.transform = transform
        self.samples = []
        name_re = re.compile(r"^([\d.]+)_([\d]+)_.+\.jpg$")
        for path in glob.glob(os.path.join(folder, "*.jpg")):
            match = name_re.match(os.path.basename(path))
            if match is None:
                continue
            ctr, impr = float(match.group(1)), float(match.group(2))
            if impr < 1e5 and ctr < 0.1:
                label = 0
            elif impr >= 3e5 or ctr >= 0.2:
                label = 2
            else:
                label = 1
            self.samples.append((path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert("RGB")
        return (self.transform(image) if self.transform else image), label

# 2) Strong augmentations for training, simple for validation
train_transform = transforms.Compose(
    [
        # geometric augmentations
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        # photometric augmentation
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
        # tensor + ImageNet normalisation
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
# deterministic pipeline: resize short side to 256, centre-crop 224
val_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# 3) Prepare data loaders
import copy

dataset = ViralAdsDataset(r"/content/sample_data", transform=None)
n_train = int(0.8 * len(dataset))
n_val = len(dataset) - n_train
# Fixed seed so the train/val split is reproducible across runs.
train_ds, val_ds = random_split(dataset, [n_train, n_val],
                                generator=torch.Generator().manual_seed(42))
# Both Subsets returned by random_split point at the SAME `dataset` object, so
# setting `.dataset.transform` on each left BOTH splits using the last one
# assigned (val_transform), i.e. training ran without augmentation. Give each
# split its own shallow copy; the copies share the read-only `samples` list
# and differ only in `transform`.
train_ds.dataset = copy.copy(dataset)
train_ds.dataset.transform = train_transform
val_ds.dataset = copy.copy(dataset)
val_ds.dataset.transform = val_transform

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, num_workers=4)

# 4) Model: ResNet-18, only the new classification head is trained
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `weights=` replaces the deprecated `pretrained=True` (same API as ProtoNet above)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for name, param in model.named_parameters():
    if not name.startswith("fc"):
        param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 3)   # new head is trainable by default
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# 5) Training loop
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate the summed loss so the epoch figure is a per-sample mean
        total_loss += loss.item() * imgs.size(0)

    epoch_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}  Train Loss: {epoch_loss:.4f}")

# 6) Evaluation
model.eval()
all_preds, all_trues = [], []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs = imgs.to(device)
        logits = model(imgs)
        preds = logits.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_trues.extend(labels.numpy())

# 7) Metrics
# zero_division=0 reports 0 precision for classes that are never predicted,
# without emitting UndefinedMetricWarning (the default also yields 0, but warns).
print(classification_report(all_trues, all_preds, digits=4, zero_division=0))
# NOTE(review): MAE/MSE/R² treat the class ids 0/1/2 as ordinal values.
mae  = mean_absolute_error(all_trues, all_preds)
mse  = mean_squared_error(all_trues, all_preds)
rmse = np.sqrt(mse)
r2   = r2_score(all_trues, all_preds)

print(f"MAE:  {mae:.4f}")
print(f"MSE:  {mse:.4f}")
print(f"RMSE: {rmse:.4f}")
print(f"R²:   {r2:.4f}")


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 105MB/s]


Epoch 1/20  Train Loss: 1.1850
Epoch 2/20  Train Loss: 0.8347
Epoch 3/20  Train Loss: 0.7534
Epoch 4/20  Train Loss: 0.7223
Epoch 5/20  Train Loss: 0.7046
Epoch 6/20  Train Loss: 0.6964
Epoch 7/20  Train Loss: 0.6965
Epoch 8/20  Train Loss: 0.6749
Epoch 9/20  Train Loss: 0.6679
Epoch 10/20  Train Loss: 0.6592
Epoch 11/20  Train Loss: 0.6561
Epoch 12/20  Train Loss: 0.6516
Epoch 13/20  Train Loss: 0.6472
Epoch 14/20  Train Loss: 0.6349
Epoch 15/20  Train Loss: 0.6284
Epoch 16/20  Train Loss: 0.6290
Epoch 17/20  Train Loss: 0.6218
Epoch 18/20  Train Loss: 0.6151
Epoch 19/20  Train Loss: 0.6058
Epoch 20/20  Train Loss: 0.6083
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         5
           1     0.4444    0.1905    0.2667        21
           2     0.8211    0.9528    0.8821       106

    accuracy                         0.7955       132
   macro avg     0.4219    0.3811    0.3829       132
weighted avg     0.7301    0.7955    0.7508

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
