<a href="https://colab.research.google.com/github/guiOsorio/Learning_JAX/blob/master/CC_Comet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- Focal loss: https://www.youtube.com/watch?v=Y8_OVwK4ECk
- Focal loss for PyTorch: https://github.com/AdeelH/pytorch-multi-class-focal-loss

In [56]:
# Install Comet
# NOTE(review): no version is pinned on any of the installs in this cell, so a
# fresh run may resolve to different releases than the ones this notebook was
# written against - consider pinning.
!pip install comet_ml --quiet
# Install Flax and JAX
# JAX is installed with the `cuda11_cudnn805` extra from the jax-releases index;
# presumably this must match the CUDA/cuDNN of the runtime - TODO confirm.
!pip install --upgrade -q "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Flax is installed straight from the GitHub repository rather than from a release.
!pip install --upgrade -q git+https://github.com/google/flax.git

In [57]:
import jax
from jax import lax, random, jit, numpy as jnp

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax.training import train_state

import optax

import torch
from torch.utils.data import Dataset, DataLoader

import functools
from typing import Sequence, Callable, Any, Optional

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [58]:
# Load data
# The credit-card fraud CSV is read directly from GitHub into a DataFrame.
# NOTE(review): this needs network access on every run; consider caching a local copy.
url = 'https://raw.githubusercontent.com/nsethi31/Kaggle-Data-Credit-Card-Fraud-Detection/master/creditcard.csv'
df = pd.read_csv(url)
# Preview the first rows as the cell output.
df.head()

Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,...,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class
0,0.0,-1.359807,-0.072781,2.536347,1.378155,-0.338321,0.462388,0.239599,0.098698,0.363787,...,-0.018307,0.277838,-0.110474,0.066928,0.128539,-0.189115,0.133558,-0.021053,149.62,0
1,0.0,1.191857,0.266151,0.16648,0.448154,0.060018,-0.082361,-0.078803,0.085102,-0.255425,...,-0.225775,-0.638672,0.101288,-0.339846,0.16717,0.125895,-0.008983,0.014724,2.69,0
2,1.0,-1.358354,-1.340163,1.773209,0.37978,-0.503198,1.800499,0.791461,0.247676,-1.514654,...,0.247998,0.771679,0.909412,-0.689281,-0.327642,-0.139097,-0.055353,-0.059752,378.66,0
3,1.0,-0.966272,-0.185226,1.792993,-0.863291,-0.010309,1.247203,0.237609,0.377436,-1.387024,...,-0.1083,0.005274,-0.190321,-1.175575,0.647376,-0.221929,0.062723,0.061458,123.5,0
4,2.0,-1.158233,0.877737,1.548718,0.403034,-0.407193,0.095921,0.592941,-0.270533,0.817739,...,-0.009431,0.798278,-0.137458,0.141267,-0.20601,0.502292,0.219422,0.215153,69.99,0


In [59]:
class CustomTensorDataset(Dataset):
  """Map-style dataset wrapping a pair of equally long tensors.

  `dataset` must unpack into exactly two tensors, `(features, labels)`, whose
  first dimensions agree. They are exposed as `.tensors` (a tuple), `.data`
  (features) and `.targets` (labels).
  """

  def __init__(self, dataset):
    features, labels = dataset
    # Every sample needs both a feature row and a label.
    assert features.size(0) == labels.size(0)
    self.tensors = (features, labels)
    self.data = features
    self.targets = labels

  def __getitem__(self, index):
    """Return the `(feature, label)` pair stored at `index`."""
    features, labels = self.tensors
    return features[index], labels[index]

  def __len__(self):
    """Number of samples, i.e. the size of the first dimension of the features."""
    features, _ = self.tensors
    return features.size(0)

# Divide into features and labels
# Features are the three columns at positions 1..3; the label is the 'Class' column.
df_x = df.iloc[:, 1:4]
df_y = df['Class'].to_frame()

# First 80% of the rows (in file order) are used for training, the rest for testing.
total_points = len(df_y)
split = round(0.8 * total_points)

# Convert pd.dataframes to tensors (once), then slice into the two partitions
all_x = torch.tensor(df_x.values, dtype=torch.float32)
all_y = torch.tensor(df_y.values, dtype=torch.float32)

train_x, test_x = all_x[:split], all_x[split:]
# Labels come out of the frame as a column; squeeze drops the singleton dimension.
train_y = all_y[:split].squeeze()
test_y = all_y[split:].squeeze()

tuple(t.size() for t in (train_x, train_y, test_x, test_y))

(torch.Size([227846, 3]),
 torch.Size([227846]),
 torch.Size([56961, 3]),
 torch.Size([56961]))

In [60]:
# Transform tensors to np arrays in dataloaders, tensors not compatible with JAX
def custom_collate_fn(batch):
    """Collate `(feature, label)` samples into two stacked NumPy arrays.

    Args:
        batch: sequence of samples; element 0 of each sample is the feature
            and element 1 is the label.

    Returns:
        `(features, labels)`, each built with `np.stack` along a new leading
        batch axis.
    """
    columns = tuple(zip(*batch))

    stacked_labels = np.stack(columns[1])
    stacked_features = np.stack(columns[0])

    return stacked_features, stacked_labels

In [61]:
# Implementation with batch norm and dropout
class NN_regularized(nn.Module):
  """MLP classifier: Dense(100) -> Dense(256) -> Dense(2) with dropout and batch norm.

  The output is `log_softmax` over the two classes, i.e. log-probabilities
  rather than raw logits or probabilities.
  """

  @nn.compact
  def __call__(self, x, train: bool):
    """Apply the network to `x`.

    `train=True` enables dropout and makes batch norm use the statistics of
    the current batch; `train=False` disables dropout and uses the running
    averages.
    """
    inference = not train

    # Linear + dropout + relu
    h1 = nn.Dense(features=100)(x)
    h1 = nn.Dropout(rate=0.2, deterministic=inference)(h1)
    h1 = nn.relu(h1)

    # Linear + batch norm + relu
    h2 = nn.Dense(features=256)(h1)
    h2 = nn.BatchNorm(use_running_average=inference)(h2)
    h2 = nn.relu(h2)

    # Linear + log-softmax
    scores = nn.Dense(features=2)(h2)
    return nn.log_softmax(scores)

In [62]:
# Replicate each positive (Class == 1) example in the training set n times
def mod_data(n, batch_size=128): # higher n => higher recall, lower precision?
  """Oversample the positive class and build the training loader and test arrays.

  Reads the module-level `train_x`, `train_y`, `test_x` and `test_y` tensors.
  Every training row whose label equals 1 is appended `n` additional times to
  the training set, which is then wrapped in a shuffling DataLoader.

  Args:
    n: number of extra copies of each positive training example to append.
    batch_size: mini-batch size of the returned training loader.

  Returns:
    `(train_loader_mod, test_features, test_lbls)`: the DataLoader over the
    oversampled training set (batches are NumPy arrays produced by
    `custom_collate_fn`), and the untouched test features / labels as JAX
    arrays.
  """
  positive_idxs = (train_y == 1).nonzero(as_tuple=True)[0]

  # repeat(n, 1) tiles the positive rows n times along the batch axis only.
  extra_xs = train_x[positive_idxs].repeat(n, 1)
  extra_ys = train_y[positive_idxs].repeat(n)

  train_x_mod = torch.cat((train_x, extra_xs), 0)
  train_y_mod = torch.cat((train_y, extra_ys), 0)

  train_dset_mod = CustomTensorDataset([train_x_mod, train_y_mod])
  train_loader_mod = DataLoader(train_dset_mod, collate_fn=custom_collate_fn, batch_size=batch_size, shuffle=True)

  # The test set is evaluated in a single pass, so it is moved to JAX as whole
  # arrays; no DataLoader is needed for it. The training set is only consumed
  # through the loader and is therefore not copied to the device here.
  test_features = jnp.array(test_x.numpy())
  test_lbls = jnp.array(test_y.numpy())

  return train_loader_mod, test_features, test_lbls

## Sanity check: print the shapes of one training batch and of the test arrays
tlm_test, tf_test, tl_test = mod_data(10)

# Only the first batch is needed to inspect the shapes.
data = next(iter(tlm_test))
x, y = data

for checked in (x, y, tf_test, tl_test):
  print(checked.shape)

(128, 3)
(128,)
(56961, 3)
(56961,)


In [63]:
# TRAINING

# Compute loss and update - this will be computed many times, so it's best to jit it
@jit
def training_state(state, imgs, gt_labels):
  """Run one optimisation step on a mini-batch.

  Args:
    state: train state carrying `params`, `batch_stats`, `step` and the optimiser.
    imgs: batch of model inputs (the name is historical; these are feature rows).
    gt_labels: ground-truth class ids (0 or 1) for the batch.

  Returns:
    `(state, metrics)`: the updated train state and a dict with the
    (regularised) training `'loss'` and the batch `'accuracy'`.
  """
  # A constant key would apply the identical dropout mask on every step, so the
  # step counter is folded into the base key to get a fresh mask per update.
  dropout_key = jax.random.fold_in(jax.random.PRNGKey(0), state.step)

  def focal_loss_fn(params, batch_stats):
    # The model ends in log_softmax, so its output is log-probabilities.
    log_probs, updates = NN_regularized().apply(
        {'params': params, 'batch_stats': batch_stats},
        imgs,
        train=True,
        rngs={'dropout': dropout_key},
        mutable=['batch_stats'],
    )
    # log_probs and one_hot_gt_labels both have shape (batch_size, num_classes)
    one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=2)
    # Log-probability assigned to the true class of each example: log(pt)
    log_pt = jnp.sum(log_probs * one_hot_gt_labels, axis=-1)

    # Focal loss: -(1 - pt)^gamma * log(pt). log_softmax is a natural log,
    # hence pt = exp(log_pt). gamma = 0 would reduce this to cross entropy.
    gamma = 2
    pt = jnp.exp(log_pt)
    loss = -jnp.mean((1.0 - pt) ** gamma * log_pt)

    # Add l2 regularization: alpha * mean(w^2), summed over every parameter leaf
    alpha = 0.1
    loss += sum(
        alpha * (w ** 2).mean()
        for w in jax.tree_util.tree_leaves(params)
      )

    return loss, (log_probs, updates)

  # Differentiate w.r.t. params only; batch_stats are updated through `updates`.
  (loss, (log_probs, updates)), grads = jax.value_and_grad(focal_loss_fn, argnums=0, has_aux=True)(state.params, state.batch_stats)
  state = state.apply_gradients(grads=grads) # update state params based on grads calculated
  state = state.replace(batch_stats=updates['batch_stats']) # update state batch_stats variables

  ## Accuracy
  accuracy = jnp.mean(jnp.argmax(log_probs, -1) == gt_labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }

  return state, metrics

# One epoch: one optimisation step per batch, metrics averaged over the batches
def train_one_epoch(state, dataloader):
  """Train for one pass over `dataloader`.

  Args:
    state: train state handed to `training_state` and threaded through the batches.
    dataloader: iterable yielding `(inputs, labels)` batches.

  Returns:
    `(state, epoch_metrics)`: the state after the last batch and a dict with
    the unweighted mean of each per-batch metric. `epoch_metrics` is empty
    when the dataloader yields no batches.
  """
  batch_metrics = []
  for imgs, labels in dataloader:
    state, metrics = training_state(state, imgs, labels)
    batch_metrics.append(metrics)

  # Nothing to average (and no keys to read) if there were no batches.
  if not batch_metrics:
    return state, {}

  batch_metrics_np = jax.device_get(batch_metrics)  # pull from the accelerator onto host (CPU)
  epoch_metrics_np = {
      k: np.mean([m[k] for m in batch_metrics_np])
      for k in batch_metrics_np[0]
  }

  return state, epoch_metrics_np

class TrainState_stats(train_state.TrainState):
  """TrainState that also carries the BatchNorm running statistics."""
  # Declared once at module level: re-declaring it on every call would create a
  # new state type each time.
  batch_stats: Any


def create_train_state(key, lr, momentum, input_size=(3,)):
  """Initialise the model and wrap it in a train state with SGD + momentum.

  Args:
    key: PRNG key used to initialise the parameters.
    lr: learning rate passed to `optax.sgd`.
    momentum: momentum passed to `optax.sgd`.
    input_size: shape of a single input example without the batch axis.
      Defaults to the three feature columns used in this notebook.

  Returns:
    A `TrainState_stats` holding the model's apply function, the initial
    params, the initial batch statistics and the optimiser.
  """
  # Create model
  NN = NN_regularized()
  # Initialize parameters with a dummy batch of one example
  variables = NN.init(key, jnp.ones([1, *input_size]), train=False)

  state = TrainState_stats.create(
    apply_fn=NN.apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
    tx=optax.sgd(lr, momentum)
  )

  return state

In [74]:
# EVALUATION

# Run one evaluation on test set
@jit
def eval_step(state, imgs, gt_labels):
  """Score a batch in inference mode (`train=False`).

  Returns `(metrics, preds)`, where `metrics` holds the mean negative
  log-likelihood under `'loss'` and the fraction of correct predictions under
  `'accuracy'`, and `preds` is the argmax class per example.
  """
  variables = {'params': state.params, 'batch_stats': state.batch_stats}
  log_probs = NN_regularized().apply(
      variables,
      imgs,
      rngs={'dropout': jax.random.PRNGKey(0)},
      train=False,
  )

  # Select the log-probability of the true class with a one-hot mask.
  target_mask = jax.nn.one_hot(gt_labels, num_classes=2)
  true_class_log_prob = jnp.sum(log_probs * target_mask, axis=-1)
  nll = -jnp.mean(true_class_log_prob)

  preds = jnp.argmax(log_probs, -1)
  hit_rate = jnp.mean(preds == gt_labels)

  return {'loss': nll, 'accuracy': hit_rate}, preds

def evaluate_model(state, test_imgs, test_labels):
  """Evaluate `state` on the given test arrays.

  Args:
    state: train state passed through to `eval_step`.
    test_imgs: test inputs.
    test_labels: ground-truth labels for `test_imgs`.

  Returns:
    `(metrics, preds)`: `metrics` is the dict from `eval_step` with every
    value converted to a Python scalar; `preds` is returned as produced by
    `eval_step`.
  """
  metrics, preds = eval_step(state, test_imgs, test_labels)
  metrics = jax.device_get(metrics) # pull from accelerator to CPU
  # jax.tree_util.tree_map is the supported spelling; the top-level
  # jax.tree_map alias is deprecated.
  metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics) # get scalar value from array
  return metrics, preds

In [75]:
# FIT
import os

from sklearn.metrics import precision_score, recall_score
from comet_ml import Experiment

# Hyperparameters shared by every run of the sweep
seed = 0
lr = 0.01 # lower learning rate with batch norm
momentum = 0.9
n_epochs = 4

# One Comet experiment per oversampling factor n
for n in range(1, 20, 2):
  # Create experiment
  # SECURITY: the API key is read from the COMET_API_KEY environment variable
  # instead of being committed in the notebook. Any key that was previously
  # hard-coded here should be treated as leaked and revoked.
  experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="credit-card-fraud",
    workspace="guiosorio",
  )
  try:
    experiment.set_name(f'{n} Repeats')
    # Log parameter
    experiment.log_parameter('Repeats', n)

    # Named `state` so that it does not shadow the imported
    # `flax.training.train_state` module.
    state = create_train_state(jax.random.PRNGKey(seed), lr, momentum)

    train_loader_mod, test_features, test_lbls = mod_data(n)
    # sklearn and Comet work on host arrays; convert the labels once per run.
    test_lbls_np = np.asarray(test_lbls).astype(int)

    for epoch in range(n_epochs):
      state, train_metrics = train_one_epoch(state, train_loader_mod)

      test_metrics, test_preds = evaluate_model(state, test_features, test_lbls)
      test_preds_np = np.asarray(test_preds).astype(int)

      # Precision and recall of the positive class on the test set
      test_metrics['precision'] = precision_score(test_lbls_np, test_preds_np)
      test_metrics['recall'] = recall_score(test_lbls_np, test_preds_np)

      # Log metrics
      experiment.log_metrics(test_metrics, step=epoch)

    # Log confusion matrix of the final epoch
    experiment.log_confusion_matrix(test_lbls_np, test_preds_np)
  finally:
    # Close the experiment even if training or logging fails part-way.
    experiment.end()

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/guiosorio/credit-card-fraud/d3526321016240efb5b533517d025c1e
COMET INFO:   Others:
COMET INFO:     Name : 1 Repeats
COMET INFO:   Uploads:
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     installed packages  : 1
COMET INFO:     notebook            : 2
COMET INFO:     os packages         : 1
COMET INFO:     source_code         : 1
COMET INFO: ---------------------------
COMET ERROR: Failed to calculate active processors count. Fall back to default CPU count 1
COMET INFO: Couldn't find a Git repository in '/content' nor in any parent directory. You can override where Comet is looking for a Git Patch by setting the configuration `COMET_GIT_DIRECTORY`
COMET INFO: Experiment is live on comet.com https://w

In [None]:
# Display the test labels left in scope by the last iteration of the sweep above.
test_lbls

DeviceArray([0., 0., 0., ..., 0., 0., 0.], dtype=float32)