In [25]:
import os
import random
import torch
import cv2
from collections import Counter
from pathlib import Path
from tqdm import tqdm
import numpy as np
from scipy import stats
import torchvision as tv
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    EarlyStopping,
)
from torch.utils.data import DataLoader, random_split, Dataset, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from PIL import Image

from src.utils.helpers import load_config
from src.training.dataset import ImageDataModule
from src.models.classification_model import ImageClassifier

In [26]:
# Collect the image paths of both classes, draw a random subset per class, and
# report the class balance.
REAL_DIR = "src/data/chameleon/real"
FAKE_DIR = "src/data/chameleon/fake"
SAMPLE_SIZE = 2000  # maximum number of images sampled per class

real_filepaths = [os.path.join(REAL_DIR, fp) for fp in os.listdir(REAL_DIR)]
fake_filepaths = [os.path.join(FAKE_DIR, fp) for fp in os.listdir(FAKE_DIR)]

real_count = len(real_filepaths)
fake_count = len(fake_filepaths)

# Sample without replacement. The size is clamped to the number of available
# files because np.random.choice raises ValueError when asked for more samples
# than the population holds (replace=False).
# NOTE(review): this draw uses NumPy's global RNG and runs before the notebook's
# seed is set (pl.seed_everything is only called in a later cell), so the subset
# is not reproducible across kernel restarts.
random_real_filepaths = np.random.choice(
    real_filepaths, size=min(SAMPLE_SIZE, real_count), replace=False
)
random_fake_filepaths = np.random.choice(
    fake_filepaths, size=min(SAMPLE_SIZE, fake_count), replace=False
)

# Guard the percentage against two empty folders (total of zero).
total_count = real_count + fake_count
real_pct = round(real_count / total_count * 100, 1) if total_count else 0.0
fake_pct = round(fake_count / total_count * 100, 1) if total_count else 0.0
print(f"REAL COUNT: {real_count} ({real_pct}%)")
print(f"FAKE COUNT: {fake_count} ({fake_pct}%)")

In [28]:
# Read the experiment configuration used by the rest of the notebook.
config = load_config("src/configs/example.yaml")

# Metrics go to a CSV logger rooted at the configured directory ("logs" if unset).
logging_cfg = config["logging"]
csv_logger = CSVLogger(logging_cfg.get("logs_dir", "logs"))

# Container for the Trainer callbacks.
callbacks = list()

# Fix the global seed from the config; kept as the last expression so the
# chosen seed is echoed as the cell output.
pl.seed_everything(config.get("seed"))

Seed set to 7643


7643

In [29]:
# Prepare data:
# Builds the project's ImageDataModule from the loaded config; the instance is
# later handed to trainer.fit / trainer.test as the `datamodule`.
# NOTE(review): the recorded training run fails during the sanity check inside
# src/training/dataset.py at `self.transforms(image=img_arr)`. The traceback shows
# a torchvision Compose calling an albumentations transform positionally
# (`t(img)`), which albumentations rejects with a KeyError. Presumably the
# pipeline there must be built with albumentations' own Compose — confirm in
# dataset.py.
data_module = ImageDataModule(config)

In [30]:
# Assemble the Trainer callbacks from the config.
# The list is rebuilt from scratch so that re-running this cell does not append
# duplicate hooks to a list left over from a previous execution.
callbacks = []
training_cfg = config["training"]

# Enable learning rate monitoring hook:
if training_cfg.get("lr_monitoring"):
    monitor_lr = LearningRateMonitor(logging_interval="epoch")
    callbacks.append(monitor_lr)

# Enable early stopping hook:
# `or {}` also covers an `early_stopping` key that is present but empty (None),
# which would otherwise make the `.get` calls below fail.
estop = training_cfg.get("early_stopping") or {}
if estop.get("enabled"):
    early_stopping = EarlyStopping(
        monitor=estop.get("monitor", "val_loss"),
        patience=estop.get("patience", 5),
        mode=estop.get("mode", "min"),
        verbose=True,
    )
    callbacks.append(early_stopping)

# Enable model checkpointing hook:
# Keeps the single checkpoint with the highest val_acc plus the most recent one.
checkpoint = ModelCheckpoint(
    dirpath=config["logging"].get("checkpoint_dir", "checkpoints"),
    filename="{epoch:02d}-{val_acc:.2f}",
    save_top_k=1,
    monitor="val_acc",
    mode="max",
    save_last=True,
)
callbacks.append(checkpoint)

In [31]:
# Build the classifier from the shared config.
net = ImageClassifier(config)

# Gather the Trainer options in one mapping, then construct the Trainer from it.
training_cfg = config["training"]
trainer_kwargs = dict(
    # Epoch bounds come from the config; max_epochs falls back to -1 when unset.
    min_epochs=training_cfg.get("min_epochs"),
    max_epochs=training_cfg.get("max_epochs", -1),
    # Logging and hooks prepared in the earlier cells.
    logger=csv_logger,
    callbacks=callbacks,
    log_every_n_steps=5,
    # Console / checkpoint features.
    enable_checkpointing=True,
    enable_progress_bar=True,
    enable_model_summary=True,
    # Let Lightning choose the hardware.
    accelerator="auto",
    devices="auto",
)
trainer = pl.Trainer(**trainer_kwargs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [32]:
# Run the training loop on `net`, with data provided through `data_module`.
# NOTE(review): the recorded output of this cell is a KeyError raised in a
# DataLoader worker from src/training/dataset.py (albumentations transform called
# without named arguments); the cell will not complete until that is resolved.
trainer.fit(net, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | loss_func | CrossEntropyLoss | 0      | train
1 | backbone  | ConvNeXt         | 27.8 M | train
-------------------------------------------------------
27.8 M    Trainable params
0         Non-trainable params
27.8 M    Total params
111.287   Total estimated model params size (MB)
250       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 50, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataset.py", line 416, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
            ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/Reveal-AI/src/training/dataset.py", line 31, in __getitem__
    aug = self.transforms(image=img_arr)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
          ^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/albumentations/core/transforms_interface.py", line 246, in __call__
    raise KeyError(msg)
KeyError: 'You have to pass data to augmentations as named arguments, for example: aug(image=image)'


In [None]:
# Test best model:
# `ckpt_path` accepts the keyword "best" or a path to an actual checkpoint file.
# The keyword is used here because a path such as "src/checkpoints/best" does not
# match any file the ModelCheckpoint callback writes (its files are named from
# "{epoch:02d}-{val_acc:.2f}", plus the "last" checkpoint).
# NOTE: "best" relies on trainer.fit having completed in this session with the
# checkpoint callback attached.
trainer.test(
    model=net,
    datamodule=data_module,
    ckpt_path="best",
)