# Training Models

In [1]:
from config import *
from tqdm.notebook import tqdm
import utils

utils.set_device()
utils.set_seed()

Using device = cpu


## Datasets and Dataloaders

In [2]:
from data.dataset import LeakAnomalyDetectionDataset
from torch.utils.data import DataLoader, Subset, random_split



In [3]:
leaks_dataset = LeakAnomalyDetectionDataset(normal_data_dir=NORMAL_DATA, anomalous_data_dir=ANOMALOUS_DATA)

train_size = int(TRAIN_SIZE * len(leaks_dataset))
val_size = int(VAL_SIZE * len(leaks_dataset))
test_size = len(leaks_dataset) - train_size - val_size
train_set, val_set, test_set = random_split(leaks_dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

# print(summary(model, input_size=train_loader.dataset[0][0].shape))


(40, 220, 3)


## Models

In [4]:
import torch
from models.classifiers  import RNNClassfier, CNNRNNClassifier
from models.autoencoders import SimpleAutoencoder, ScheduledSamplingAutoencoder

### Classification Models

In [11]:
model = RNNClassfier(3, 16)
model.configure_optimizers()


### Autoencoder Models

In [None]:
model = SimpleAutoencoder(3, 16)

## Training

In [None]:

best_val_loss = 1_000_000

for epoch in range(100):
  with tqdm(total=len(train_loader), desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='batch') as pbar:

    # Metrics
    metrics = {
      'loss': 0.0,
      'val_loss': 0.0
    }
    
    model.train()
    for batch in train_loader:
        model.optimizer.zero_grad()
        loss = model.training_step(batch)
        loss.backward()
        model.optimizer.step()
        metrics['loss'] += loss.item()
        pbar.update()
        
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            loss = model.validation_step(batch)
            metrics['val_loss'] += loss.item()
    
    model.lr_scheduler.step(metrics['val_loss'])
    
    # Logging
    metrics['loss'] /= len(train_loader)
    metrics['val_loss'] /= len(val_loader)
    pbar.set_postfix(metrics)
      
    
    # Perform checkpointing here

In [15]:
from sklearn.metrics import accuracy_score, det_curve
import matplotlib.pyplot as plt

In [16]:
with torch.no_grad():
  for x, y_true in test_loader:
    y_scores = model.forward(x)
    print(y_scores, y_true)
    fpr, fnr, thresholds = det_curve(y_true, y_scores)
    print(fpr, fnr)

tensor([9.9999e-01, 9.9999e-01, 4.6847e-05, 6.7276e-05]) tensor([1., 1., 0., 0.])
[0.] [0.]
