# Setup

In [1]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

Tue Nov  8 14:54:31 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 46%   29C    P8     6W / 105W |    843MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops

import json
import pickle

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *
from utils.metrics import *


In [3]:
# Report whether a CUDA device is usable by this kernel.
# `torch.cuda.is_available` is a function and must be *called*: the bare
# function object is always truthy, which made the CPU branch unreachable.
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    print('Good to go!')
else:
    print('Training might be rather slow')

Good to go!


## Model Training


In [4]:
# --- Run configuration -------------------------------------------------------
train = True          # when False the whole training branch below is skipped
save_metrics = False  # when True, metrics.update_model / get_metrics run every 100 epochs

task_dir = "1L_MLP_sym_S5"
# load_cfg(task_dir) is unpacked into the ten configuration values used below.
seed, frac_train, layers, lr, group_param, weight_decay, num_epochs, group_type, architecture_type, metrics = load_cfg(task_dir)
group = group_type(group_param)

# NOTE: when `train` is False nothing below runs, so `model` and the history
# lists are not (re)defined by this cell.
if train:
    train_data, test_data, train_labels, test_labels = generate_train_test_data(group, frac_train, seed)

    # Per-epoch history: one entry is appended to each list on every epoch.
    train_losses = []
    test_losses = []
    train_accs = []
    test_accs = []

    # Use the GPU when one is available and fall back to the CPU otherwise,
    # instead of unconditionally calling model.cuda() (which raises on a
    # CUDA-less machine).
    # NOTE(review): assumes generate_train_test_data returns tensors on the
    # matching device -- TODO confirm.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = architecture_type(layers, group.order, seed)
    model.to(device)

    # `metrics` is rebound here: it held the value returned by load_cfg and
    # from now on refers to the Metrics object built from it.
    metrics = Metrics(model, group, metrics)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in tqdm(range(num_epochs)):
        # Full-batch training step: the whole training set is passed at once.
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_losses.append(train_loss.item())

        # Evaluation only: no gradients are tracked inside this context.
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())
            # Accuracy = fraction of examples whose arg-max over dim 1 equals
            # the label. The train accuracy reuses the logits computed before
            # this epoch's optimizer step.
            train_acc = (train_logits.argmax(1) == train_labels).sum() / len(train_labels)
            test_acc = (test_logits.argmax(1) == test_labels).sum() / len(test_labels)
            train_accs.append(train_acc.item())
            test_accs.append(test_acc.item())

            if save_metrics and epoch % 100 == 0:
                metrics.update_model(model)
                metrics.get_metrics()

        # Print a progress line every 1000 epochs.
        if epoch % 1000 == 0:
            print(f"Epoch:{epoch}, Train: L: {train_losses[-1]:.6f} A: {train_accs[-1]*100:.4f}%, Test: L: {test_losses[-1]:.6f} A: {test_accs[-1]*100:.4f}%")



  0%|          | 0/50000 [00:00<?, ?it/s]

Epoch:0, Train: L: 4.787771 A: 0.6806%, Test: L: 4.790697 A: 0.6806%
Epoch:1000, Train: L: 0.038816 A: 100.0000%, Test: L: 16.920149 A: 0.1667%
Epoch:2000, Train: L: 0.016403 A: 100.0000%, Test: L: 15.663980 A: 0.4028%
Epoch:3000, Train: L: 0.008188 A: 100.0000%, Test: L: 14.092774 A: 0.9444%
Epoch:4000, Train: L: 0.004464 A: 100.0000%, Test: L: 12.808044 A: 1.9167%
Epoch:5000, Train: L: 0.002524 A: 100.0000%, Test: L: 11.757225 A: 2.8056%
Epoch:6000, Train: L: 0.001466 A: 100.0000%, Test: L: 10.817681 A: 3.9583%
Epoch:7000, Train: L: 0.000853 A: 100.0000%, Test: L: 9.717101 A: 5.9861%
Epoch:8000, Train: L: 0.000500 A: 100.0000%, Test: L: 8.483696 A: 9.1806%
Epoch:9000, Train: L: 0.000297 A: 100.0000%, Test: L: 7.302490 A: 13.4306%
Epoch:10000, Train: L: 0.000177 A: 100.0000%, Test: L: 6.245217 A: 18.2639%
Epoch:11000, Train: L: 0.000106 A: 100.0000%, Test: L: 5.176885 A: 25.0417%
Epoch:12000, Train: L: 0.000064 A: 100.0000%, Test: L: 4.011427 A: 33.8750%
Epoch:13000, Train: L: 0.00003

In [5]:
# Plot the loss and accuracy histories; each `lines` call is given a target
# path under task_dir through its `save` argument.
loss_plot_path = f"{task_dir}/loss.png"
acc_plot_path = f"{task_dir}/acc.png"
lines(
    [train_losses, test_losses],
    log_y=True,
    labels=['train loss', 'test loss'],
    save=loss_plot_path,
)
lines(
    [train_accs, test_accs],
    log_y=False,
    labels=['train acc', 'test acc'],
    save=acc_plot_path,
)

# Store the model's state dict next to the plots.
weights_path = f"{task_dir}/model.pt"
torch.save(model.state_dict(), weights_path)

# The metrics data is pickled only when metric collection was switched on.
if save_metrics:
    metrics_path = f'{task_dir}/metrics.pkl'
    with open(metrics_path, 'wb') as f:
        pickle.dump(metrics.data, f)
