## Imports

In [1]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.optim import Adam
from pytorch_lightning.loggers import WandbLogger
import matplotlib.pyplot as plt
import numpy as np
import os

## Model

In [2]:
class SimpleCNN(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        # For 224x224 input: after two pooling layers -> 56x56
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 1)
        
        self.dropout = nn.Dropout(0.3)
    
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) 
        x = self.pool(F.relu(self.conv2(x)))
        
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        images, labels = batch
        labels = labels.float().unsqueeze(1)
        outputs = self(images)
        loss = F.binary_cross_entropy_with_logits(outputs, labels)
        
        acc = ((outputs.sigmoid() > 0.5) == labels).float().mean()
        
        self.log("train/loss", loss, prog_bar=True)
        self.log("train/acc", acc, prog_bar=True)
        
        self.train_losses.append(loss)
        self.train_accs.append(acc)
        
        return loss
    
    def on_train_epoch_end(self):
        avg_loss = torch.stack(self.train_losses).mean()
        avg_acc = torch.stack(self.train_accs).mean()
        
        self.log("train/epoch_loss", avg_loss, prog_bar=True)
        self.log("train/epoch_acc", avg_acc, prog_bar=True)
        self.log("step", self.current_epoch)
        
        self.train_losses = []
        self.train_accs = []

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        labels = labels.float().unsqueeze(1)
        outputs = self(images)
        loss = F.binary_cross_entropy_with_logits(outputs, labels)
        acc = ((outputs.sigmoid() > 0.5) == labels).float().mean()
        
        self.log("val/loss", loss, prog_bar=True)
        self.log("val/acc", acc, prog_bar=True)
        
        self.val_losses.append(loss)
        self.val_accs.append(acc)
        
        return loss
    
    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.val_losses).mean()
        avg_acc = torch.stack(self.val_accs).mean()
        
        self.log("val/epoch_loss", avg_loss, prog_bar=True)
        self.log("val/epoch_acc", avg_acc, prog_bar=True)
        self.log("step", self.current_epoch)
        
        self.print(f"Epoch {self.current_epoch}: Val Loss: {avg_loss:.4f}, Val Acc: {avg_acc:.4f}")
        
        self.val_losses = []
        self.val_accs = []

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)

## Dataset

In [3]:
class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.data = ImageFolder(root=root_dir, transform=transform)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_data = ImageDataset("data/train", transform=transform)
val_data = ImageDataset("data/val", transform=transform)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=0)
val_loader = DataLoader(val_data, batch_size=128, shuffle=False, num_workers=0)

## Training

In [4]:
model = SimpleCNN(lr=1e-3)

wandb_logger = WandbLogger(project="dead-leaves-binary-membership-classifier", name="SimpleCNN + lr1e-3 + Balanced Data + 20epochs")

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    logger=wandb_logger
)

trainer.fit(model, train_loader, val_loader)

trainer.save_checkpoint('model.ckpt')

wandb.finish()

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/glibesyck/miniconda3/envs/dead-leaves/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgsolodzhuk[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name    | Type      | Params | Mode 
----------------------------------------------
0 | conv1   | Conv2d    | 160    | train
1 | conv2   | Conv2d    | 4.6 K  | train
2 | pool    | MaxPool2d | 0      | train
3 | fc1     | Linear    | 12.8 M | train
4 | fc2     | Linear    | 129    | train
5 | dropout | Dropout   | 0      | train
----------------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.400    Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Users/glibesyck/miniconda3/envs/dead-leaves/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0: Val Loss: 0.6583, Val Acc: 1.0000


/Users/glibesyck/miniconda3/envs/dead-leaves/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/glibesyck/miniconda3/envs/dead-leaves/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (11) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 0: Val Loss: 0.7277, Val Acc: 0.4453


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 1: Val Loss: 0.6780, Val Acc: 0.5612


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 2: Val Loss: 0.6740, Val Acc: 0.5861


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 3: Val Loss: 0.6672, Val Acc: 0.5849


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 4: Val Loss: 0.6671, Val Acc: 0.6030


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 5: Val Loss: 0.6465, Val Acc: 0.6218


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 6: Val Loss: 0.6389, Val Acc: 0.6531


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 7: Val Loss: 0.6365, Val Acc: 0.6432


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 8: Val Loss: 0.6150, Val Acc: 0.6713


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 9: Val Loss: 0.5811, Val Acc: 0.6719


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 10: Val Loss: 0.5647, Val Acc: 0.6894


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 11: Val Loss: 0.5389, Val Acc: 0.6980


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 12: Val Loss: 0.5294, Val Acc: 0.7219


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 13: Val Loss: 0.5255, Val Acc: 0.7115


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 14: Val Loss: 0.5180, Val Acc: 0.7326


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 15: Val Loss: 0.5089, Val Acc: 0.7492


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 16: Val Loss: 0.4944, Val Acc: 0.7563


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 17: Val Loss: 0.5113, Val Acc: 0.7526


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 18: Val Loss: 0.5161, Val Acc: 0.7512


Validation: |                                             | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: Val Loss: 0.5257, Val Acc: 0.7604


0,1
epoch,▁▃▅█
train/acc,▁▄▇█
train/epoch_acc,▁▁▂▂▂▃▃▃▄▄▅▅▆▆▇▇▇▇██
train/epoch_loss,█▄▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▁▁▁
train/loss,█▆▂▁
trainer/global_step,▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▆▁▁▁▂▂▂▂▂▂█▂▂▂
val/acc,▁▁▃▃▄▄▅▅▆▅▆▆▆▇▇▇████
val/epoch_acc,▁▄▄▄▅▅▆▅▆▆▆▇▇▇▇█████
val/epoch_loss,█▇▆▆▆▆▅▅▅▄▃▂▂▂▂▁▁▂▂▂
val/loss,█▇▇▇▇▆▆▅▅▄▃▂▂▁▂▂▁▁▂▁

0,1
epoch,18.0
train/acc,0.92188
train/epoch_acc,0.92822
train/epoch_loss,0.20921
train/loss,0.25096
trainer/global_step,19.0
val/acc,0.7617
val/epoch_acc,0.76042
val/epoch_loss,0.52567
val/loss,0.51877
