In [7]:
#### ============================================
# CIFAR-100-LT + ResNet-32
# Balanced Softmax + (2.7 - f_c)^(0.5 - p_t) weighting
# Single-stage, 13k iterations, warmup + cosine
# Augmentation as in EQL / BALMS (AutoAug + Cutout)
# ============================================

import os
import math
import json
import random
import pickle
from collections import Counter

import numpy as np
from PIL import Image, ImageEnhance, ImageOps

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# ----------------------------
#  Config
# ----------------------------

IMB_FACTOR = 100.0       # imbalance ratio
BATCH_SIZE = 512

TOTAL_ITERS = 13000
WARMUP_ITERS = 800

BASE_LR = 0.05           # start of warmup
WARMUP_LR = 0.1          # target LR after warmup
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
EPS = 1e-6
SEED = 42

# >>> SET THIS to your CIFAR-100 pickle folder <<<
# Folder must contain "train" and "test" pickles like standard CIFAR-100
DATA_ROOT = "/kaggle/input/hihihi"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

# =========================================================
#  Cutout
# =========================================================

class Cutout(object):
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        # img: Tensor [C, H, W]
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1:y2, x1:x2] = 0.

        mask = torch.from_numpy(mask).to(img.device if img.is_cuda else "cpu", img.dtype)
        mask = mask.expand_as(img)
        img = img * mask
        return img

# =========================================================
#  AutoAugment CIFAR10 Policy (as used in EQL / BALMS)
# =========================================================

class SubPolicy(object):
    def __init__(self, p1, operation1, magnitude_idx1,
                 p2, operation2, magnitude_idx2,
                 fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10,
        }

        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(
                rot, Image.new("RGBA", rot.size, (128,) * 4), rot
            ).convert(img.mode)

        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE,
                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE,
                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE,
                (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE,
                (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, int(magnitude)),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img),
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img


class CIFAR10Policy(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"

# =========================================================
#  Custom CIFAR-100-LT loader (uses raw train/test/meta)
# =========================================================

class CIFAR100LT(Dataset):
    cls_num = 100
    dataset_name = "CIFAR-100-LT"

    def __init__(self, root, phase, imbalance_ratio=1.0, imb_type="exp"):
        """
        root: folder containing 'train', 'test', 'meta'
        phase: 'train' or 'test'
        """
        self.root = root
        self.train = (phase == "train")

        data_file = "train" if self.train else "test"
        path = os.path.join(root, data_file)
        print(f"Loading CIFAR-100 from {path}")
        with open(path, "rb") as f:
            entry = pickle.load(f, encoding="latin1")

        self.data = entry["data"]
        self.data = self.data.reshape((-1, 3, 32, 32)).transpose((0, 2, 3, 1))
        self.targets = entry["fine_labels"]

        if self.train:
            self.img_num_per_cls = self.get_img_num_per_cls(
                self.cls_num, imb_type, imbalance_ratio
            )
            self.gen_imbalanced_data(self.img_num_per_cls)

            # Augmentation as in EQL/BALMS: crop + flip + AutoAug + Cutout + norm
            self.transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                CIFAR10Policy(),
                transforms.ToTensor(),
                Cutout(n_holes=1, length=16),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465),
                    (0.2023, 0.1994, 0.2010),
                ),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465),
                    (0.2023, 0.1994, 0.2010),
                ),
            ])

        self.labels = self.targets
        print(f"{phase} Mode: Contain {len(self.data)} images")

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        gamma = 1.0 / imb_factor
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == "exp":
            for cls_idx in range(cls_num):
                num = img_max * (gamma ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == "step":
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * gamma))
        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)

        os.makedirs("cls_freq", exist_ok=True)
        freq_path = os.path.join(
            "cls_freq", self.dataset_name + "_IMBA{}.json".format(imb_factor)
        )
        with open(freq_path, "w") as fd:
            json.dump(img_num_per_cls, fd)

        return img_num_per_cls

    def gen_imbalanced_data(self, img_num_per_cls):
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = {}
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class] * the_img_num)
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets

    def __getitem__(self, index):
        img, label = self.data[index], self.labels[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, label, index

    def __len__(self):
        return len(self.labels)

    def get_num_classes(self):
        return self.cls_num

    def get_cls_num_list(self):
        count = Counter(self.labels)
        return [count[i] for i in range(self.cls_num)]

# =========================================================
#  CIFAR ResNet-32 (feature backbone)
# =========================================================

def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd
    def forward(self, x):
        return self.lambd(x)

class BasicBlockCifar(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, option="A"):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == "A":
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(
                        x[:, :, ::2, ::2],
                        (0, 0, 0, 0, planes // 4, planes // 4),
                        "constant", 0,
                    )
                )
            elif option == "B":
                self.shortcut = nn.Sequential(
                    nn.Conv2d(
                        in_planes,
                        self.expansion * planes,
                        kernel_size=1,
                        stride=stride,
                        bias=False,
                    ),
                    nn.BatchNorm2d(self.expansion * planes),
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out, inplace=True)
        return out

class ResNet_Cifar(nn.Module):
    def __init__(self, block, num_blocks):
        super().__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(
            3, 16, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)

        self.apply(_weights_init)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        feature_maps = out
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)   # [B, 64]
        return out, feature_maps

class DotProductClassifier(nn.Module):
    def __init__(self, feat_dim=64, num_classes=100):
        super().__init__()
        self.fc = nn.Linear(feat_dim, num_classes)
    def forward(self, x):
        return self.fc(x)

# =========================================================
#  Balanced Softmax + (2.7 - f_c)^(0.5 - p_t) weighting
#   logits'_k = z_k + log(n_k)
#   p = softmax(logits')
#   L_i = (2.7 - f_c)^(0.5 - p_t) * (-log p_t)
# =========================================================

class BalancedSoftmaxExpWeightedLoss(nn.Module):
    """
    Balanced Softmax with per-sample weight:

        logits'_k = z_k + log(n_k)
        p = softmax(logits')
        p_t = p_{true}

        L_i = (2.7 - f_{c_i})^(0.5 - p_t) * (-log p_t)

    where:
        f_c = normalized class frequency for class c.
        2.7 ~ exp(1).
    """
    def __init__(self, class_counts, eps=1e-12, base=2.7):
        super().__init__()
        counts = np.array(class_counts, dtype=np.float32)
        counts = counts + eps           # avoid log(0)
        freqs = counts / counts.sum()   # f_c

        log_counts = np.log(counts)

        self.register_buffer("log_counts",
                             torch.tensor(log_counts, dtype=torch.float32))
        self.register_buffer("freqs",
                             torch.tensor(freqs, dtype=torch.float32))

        self.base = float(base)

    def set_state(self, it=None, lr=None):
        # no-op to keep training loop unchanged
        return

    def forward(self, logits, targets):
        # Balanced Softmax logits: z_k + log(n_k)
        balanced_logits = logits + self.log_counts.unsqueeze(0)   # [B, C]

        # log p and p under balanced softmax
        log_probs = F.log_softmax(balanced_logits, dim=1)         # [B, C]
        probs = log_probs.exp()                                   # [B, C]

        # p_t for each sample
        pt = probs.gather(1, targets.view(-1, 1)).squeeze(1)      # [B]

        # f_c for each sample
        freqs_t = self.freqs[targets]                             # [B]

        # (2.7 - f_c)^(0.5 - p_t)
        base_t = self.base - freqs_t                              # [B], > 0
        weight = torch.pow(base_t, 0.75 - pt)                      # [B]

        # CE: -log p_t
        ce_per_sample = -log_probs.gather(1, targets.view(-1, 1)).squeeze(1)  # [B]

        loss = (weight * ce_per_sample).mean()
        return loss

# =========================================================
#  Iter-based LR scheduler (cosine with warmup)
# =========================================================

class IterLRScheduler:
    def __init__(self, optimizer, lr_init, lr_base, warmup_iters, total_iters):
        self.optimizer = optimizer
        self.lr_init = lr_init
        self.lr_base = lr_base
        self.warmup_iters = max(1, warmup_iters)
        self.total_iters = max(1, total_iters)
        self.iter = 0

    def step(self):
        self.iter += 1
        t = self.iter
        if t <= self.warmup_iters:
            ratio = t / float(self.warmup_iters)
            lr = self.lr_init + (self.lr_base - self.lr_init) * ratio
        else:
            decay_iter = t - self.warmup_iters
            decay_total = max(1, self.total_iters - self.warmup_iters)
            cosine = 0.5 * (1.0 + math.cos(math.pi * decay_iter / decay_total))
            lr = self.lr_base * cosine

        for pg in self.optimizer.param_groups:
            pg["lr"] = lr

# =========================================================
#  Build datasets & loaders
# =========================================================

train_dataset = CIFAR100LT(
    root=DATA_ROOT,
    phase="train",
    imbalance_ratio=IMB_FACTOR,
    imb_type='exp',
)
test_dataset = CIFAR100LT(
    root=DATA_ROOT,
    phase="test",
    imbalance_ratio=1.0,
    imb_type='exp',   # unused for test, just for consistency
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

class_counts = train_dataset.get_cls_num_list()
print("Class counts (first 10):", class_counts[:10])

# =========================================================
#  Model, loss, optimizer, scheduler
# =========================================================

backbone = ResNet_Cifar(BasicBlockCifar, [5, 5, 5]).to(device)  # ResNet-32
classifier = DotProductClassifier(feat_dim=64, num_classes=100).to(device)

criterion = BalancedSoftmaxExpWeightedLoss(
    class_counts,
    eps=1e-12,
    base=2.718,   # ~ exp(1)
).to(device)

optimizer = torch.optim.SGD(
    list(backbone.parameters()) + list(classifier.parameters()),
    lr=BASE_LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
    nesterov=True,
)

scheduler = IterLRScheduler(
    optimizer,
    lr_init=BASE_LR,
    lr_base=WARMUP_LR,
    warmup_iters=WARMUP_ITERS,
    total_iters=TOTAL_ITERS,
)

# =========================================================
#  Eval helper
# =========================================================

@torch.no_grad()
def evaluate():
    backbone.eval()
    classifier.eval()
    ce = nn.CrossEntropyLoss(reduction="sum").to(device)
    total_loss = 0.0
    total_correct = 0
    total = 0
    for x, y, idx in test_loader:
        x, y = x.to(device), y.to(device)
        feats, _ = backbone(x)
        logits = classifier(feats)
        loss = ce(logits, y)
        total_loss += loss.item()
        preds = logits.argmax(1)
        total_correct += (preds == y).sum().item()
        total += y.size(0)
    return total_loss / total, total_correct / total

# =========================================================
#  Training loop (13k iterations, single stage)
# =========================================================

best_acc = 0.0
it = 0

print(f"Starting single-stage training: TOTAL_ITERS={TOTAL_ITERS}, "
      f"WARMUP_ITERS={WARMUP_ITERS}, IMB_FACTOR={IMB_FACTOR}")

while it < TOTAL_ITERS:
    for x, y, idx in train_loader:
        it += 1
        if it > TOTAL_ITERS:
            break

        x, y = x.to(device), y.to(device)

        backbone.train()
        classifier.train()
        scheduler.step()

        # get current LR and update loss state (no-op here, but kept for compatibility)
        current_lr = optimizer.param_groups[0]["lr"]
        criterion.set_state(it, current_lr)

        feats, _ = backbone(x)
        logits = classifier(feats)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if it % 100 == 0:
            lr = optimizer.param_groups[0]["lr"]
            with torch.no_grad():
                batch_acc = (logits.argmax(1) == y).float().mean().item()
            print(
                f"[Iter {it}/{TOTAL_ITERS}] "
                f"loss={loss.item():.4f}, "
                f"train_acc={batch_acc*100:.2f}%, "
                f"lr={lr:.5f}"
            )

        if it % 1000 == 0 or it == TOTAL_ITERS:
            val_loss, val_acc = evaluate()
            best_acc = max(best_acc, val_acc)
            print(
                f"  Eval @ iter {it}: "
                f"val_loss={val_loss:.4f}, "
                f"val_acc={val_acc*100:.2f}%, "
                f"best={best_acc*100:.2f}%"
            )

print(f"Training complete. Best test accuracy: {best_acc*100:.2f}%")


Using device: cuda
Loading CIFAR-100 from /kaggle/input/hihihi/train
train Mode: Contain 12608 images
Loading CIFAR-100 from /kaggle/input/hihihi/test
test Mode: Contain 10000 images
Class counts (first 10): [500, 480, 462, 444, 426, 410, 394, 379, 364, 350]
Starting single-stage training: TOTAL_ITERS=13000, WARMUP_ITERS=800, IMB_FACTOR=50.0
[Iter 100/13000] loss=3.8855, train_acc=9.57%, lr=0.05625
[Iter 200/13000] loss=3.4933, train_acc=16.99%, lr=0.06250
[Iter 300/13000] loss=3.2418, train_acc=21.29%, lr=0.06875
[Iter 400/13000] loss=3.2090, train_acc=23.05%, lr=0.07500
[Iter 500/13000] loss=3.0154, train_acc=28.32%, lr=0.08125
[Iter 600/13000] loss=2.6500, train_acc=32.42%, lr=0.08750
[Iter 700/13000] loss=2.4229, train_acc=38.48%, lr=0.09375
[Iter 800/13000] loss=2.5885, train_acc=33.98%, lr=0.10000
[Iter 900/13000] loss=2.1899, train_acc=41.21%, lr=0.09998
[Iter 1000/13000] loss=2.2780, train_acc=43.75%, lr=0.09993
  Eval @ iter 1000: val_loss=3.7443, val_acc=23.44%, best=23.44%
[