<a href="https://colab.research.google.com/github/chengyang122/mutitaskNAM/blob/main/MutiTaskClassificationTutoral.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/chengyang122/mutitaskNAM.git

fatal: destination path 'mutitaskNAM' already exists and is not an empty directory.


In [2]:
cd mutitaskNAM

/content/mutitaskNAM


In [3]:
import os
import tqdm
import copy
import random
import logging
from absl import app
from absl import flags
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
from nam.metrics import *
import nam.data_utils
from nam.model import *

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
features = pd.read_csv('features.csv', index_col=0)
target = pd.read_csv('target.csv', index_col=0)
s = target['0']
oneHotTarget = pd.get_dummies(s)
x_train = features.to_numpy()[:400]
y_train = oneHotTarget.to_numpy()[:400]
x_val = features.to_numpy()[400:]
y_val = oneHotTarget.to_numpy()[400:]

In [28]:
y_train.shape

(400, 200)

In [26]:
y_train[1].shape

(200,)

In [5]:
model = NeuralAdditiveModel(
        input_size=x_train.shape[-1],
        # feature size, 0 is sample and 1 is the feature, this is one iter of torch dataloader
        output_size=1 if len(y_train.shape)==1 else y_train.shape[-1],
        shallow_units=nam.data_utils.calculate_n_units(x_train, 1000, 2),
        # for feature network, it is changing with data and I am not sure why
        hidden_units=list(map(int, [])),  # for feature network
        shallow_layer=ExULayer,  # special operational layer designed for this model
        hidden_layer=ExULayer,
        hidden_dropout=0.3,
        feature_dropout=0.0).to(device)

In [6]:
train_dataset = TensorDataset(torch.tensor(x_train), torch.tensor(y_train))
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
val_dataset = TensorDataset(torch.tensor(x_val), torch.tensor(y_val))
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=True)

In [7]:
regression = False
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=1e-3,
                              weight_decay=0.0)
criterion = penalized_cross_entropy_MutiTask
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.995, step_size=1)


In [8]:
def feature_loss(fnn_out, lambda_=0.):
    return lambda_ * (fnn_out ** 2).sum() / fnn_out.shape[1]

def penalized_cross_entropy(logits, truth, fnn_out, feature_penalty=0.):
    loss = torch.nn.CrossEntropyLoss()
    return loss(logits, truth.argmax(-1)) + feature_loss(fnn_out, feature_penalty)

In [9]:
def train_one_epoch(model, criterion, optimizer, data_loader, device):
    pbar = tqdm.tqdm(enumerate(data_loader, start=1), total=len(data_loader))
    total_loss = 0
    for i, (x, y) in pbar:
        x, y = x.to(device), y.to(device)
        logits, fnns_out = model.forward(x)
        loss = criterion(logits, y, fnns_out, feature_penalty=0.0)
        total_loss -= (total_loss / i) - (loss.item() / i)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_description(f"train | loss = {total_loss:.5f}")
    return total_loss

In [10]:
def evaluate(model, data_loader, device):
    total_score = 0
    metric = None
    for i, (x, y) in enumerate(data_loader, start=1):
        x, y = x.to(device), y.to(device)
        logits, fnns_out = model.forward(x)
        metric, score = calculate_metric(logits, y, regression=False)
        total_score -= (total_score / i) - (score / i)
    return metric, total_score

In [24]:
val_scores = []
best_validation_score, best_weights = 0, None
n_tries = 60
for epoch in range(300):
    model = model.train()
    total_loss = train_one_epoch(model, criterion, optimizer, train_loader, device)
    logging.info(f"epoch {epoch} | train | {total_loss}")

    scheduler.step()

    model = model.eval()
    metric, val_score = evaluate(model, val_loader, device)
    metric, train_score = evaluate(model, val_loader, device)
    print(f"epoch {epoch} | validate | {metric}={val_score}")
    print(f"epoch {epoch} | train | {metric}={train_score}")
    # early stopping
    if val_score <= best_validation_score and n_tries > 0:
        n_tries -= 1
        continue
    elif val_score <= best_validation_score:
        logging.info(f"early stopping at epoch {epoch}")
        break
    best_validation_score = val_score
    best_weights = copy.deepcopy(model.state_dict())
    val_scores.append(val_score)
model.load_state_dict(best_weights)

train | loss = 0.10467: 100%|██████████| 40/40 [00:04<00:00,  8.18it/s]


epoch 0 | validate | accuracy=0.0042204547747546305
epoch 0 | train | accuracy=0.00422196992598897


train | loss = 0.08773: 100%|██████████| 40/40 [00:03<00:00, 11.00it/s]


epoch 1 | validate | accuracy=0.004227651745102117
epoch 1 | train | accuracy=0.0042284093188672215


train | loss = 0.07007: 100%|██████████| 40/40 [00:03<00:00, 10.88it/s]


epoch 2 | validate | accuracy=0.004252272959671573
epoch 2 | train | accuracy=0.004252272960200734


train | loss = 0.06839: 100%|██████████| 40/40 [00:03<00:00, 11.11it/s]


epoch 3 | validate | accuracy=0.004251515384318983
epoch 3 | train | accuracy=0.004250757808437228


train | loss = 0.10170: 100%|██████████| 40/40 [00:03<00:00, 10.78it/s]


epoch 4 | validate | accuracy=0.004221212350636384
epoch 4 | train | accuracy=0.004222727500812403


train | loss = 0.07839: 100%|██████████| 40/40 [00:04<00:00,  8.92it/s]


epoch 5 | validate | accuracy=0.004250757808437233
epoch 5 | train | accuracy=0.004250757809760136


train | loss = 0.07928: 100%|██████████| 40/40 [00:04<00:00,  9.78it/s]


epoch 6 | validate | accuracy=0.004252272961788216
epoch 6 | train | accuracy=0.004252272959671571


train | loss = 0.06622: 100%|██████████| 40/40 [00:03<00:00, 10.86it/s]


epoch 7 | validate | accuracy=0.004252272958084096
epoch 7 | train | accuracy=0.004252272960200744


train | loss = 0.07854: 100%|██████████| 40/40 [00:03<00:00, 10.93it/s]


epoch 8 | validate | accuracy=0.004251515384848147
epoch 8 | train | accuracy=0.004250757810553875


train | loss = 0.08138: 100%|██████████| 40/40 [00:03<00:00, 10.95it/s]


epoch 9 | validate | accuracy=0.004251515385906467
epoch 9 | train | accuracy=0.004251515384318985


train | loss = 0.08480: 100%|██████████| 40/40 [00:03<00:00, 10.83it/s]


epoch 10 | validate | accuracy=0.0042515153874939406
epoch 10 | train | accuracy=0.00425075780896639


train | loss = 0.07056: 100%|██████████| 40/40 [00:03<00:00, 10.91it/s]


epoch 11 | validate | accuracy=0.004250757807908074
epoch 11 | train | accuracy=0.004252272959671579


train | loss = 0.07120: 100%|██████████| 40/40 [00:03<00:00, 10.85it/s]


epoch 12 | validate | accuracy=0.004252272962846541
epoch 12 | train | accuracy=0.004251515385377305


train | loss = 0.06489: 100%|██████████| 40/40 [00:03<00:00, 10.84it/s]


epoch 13 | validate | accuracy=0.004251515386435628
epoch 13 | train | accuracy=0.004250757808437224


train | loss = 0.06909: 100%|██████████| 40/40 [00:03<00:00, 10.85it/s]


KeyboardInterrupt: ignored