In [1]:
import math
import random
import yaml
import argparse
from dotmap import DotMap

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.nn.functional import cosine_similarity

import matplotlib.pyplot as plt
import wandb

In [2]:
import sys
sys.path.append("./src")  # must run before the imports below so Python can find the modules under src/
import data  # sampler classes are later looked up on this module by task name
from model_linear import GPTLinear
from model_softmax import GPTSoftmax
# from train_step import train_step  # alternative train_step, currently disabled
from multi_task_train import train_step

In [3]:
# Parse the experiment configuration from YAML into a plain Python object.
with open("src/configs/mix1_mws_mwp.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Re-read the YAML in this cell so re-running it always starts from the file's
# values, then wrap the result in a DotMap for attribute-style access.
with open("src/configs/mix1_mws_mwp.yaml", "r") as f:
    config = DotMap(yaml.safe_load(f))

# Model hyperparameters derived from the data settings.
config.model.vocab_size = 1 + max(config.data.p, config.data.max_num)  # vocabulary size of the model
config.model.block_size = 1 + 2 * config.data.num_tokens  # length of each sequence

In [5]:
# Build one sampler per configured task, keyed by the task's name.
num_task = len(config.data.tasks)

# Keyword arguments that are identical for every task's sampler.
shared_kwargs = dict(
    min_num=config.data.min_num,
    max_num=config.data.max_num,
    k=config.data.k,
    p=config.data.p,
)

data_samplers = {}
for task in config.data.tasks:
    # The sampler class is the attribute of the `data` module named by task.name;
    # only the separator differs from task to task.
    data_samplers[task.name] = getattr(data, task.name)(sep=task.sep, **shared_kwargs)

In [6]:
# Display the task-name -> sampler mapping built in the previous cell.
data_samplers

{'MovingWindowSum': <data.MovingWindowSum at 0x7f3b908c3580>,
 'MovingWindowProduct': <data.MovingWindowProduct at 0x7f3b908c36a0>}

In [11]:
## Initialize the model: config.model.linear selects GPTLinear over GPTSoftmax.
# Either variant is built with return_att=True and moved to `device`.
if config.model.linear:
    model = GPTLinear(config.model, return_att=True).to(device)
else:
    model = GPTSoftmax(config.model, return_att=True).to(device)

optim = Adam(model.parameters(), lr=config.train.lr)

## Optional Weights & Biases logging, enabled by config.train.wandb.
if config.train.wandb:
    wandb_run_name = config.train.wandb_run_name
    # Do not embed an API key in the notebook: with no `key` argument,
    # wandb.login() resolves credentials from the WANDB_API_KEY environment
    # variable or a previously stored `wandb login`.
    wandb.login()
    wandb.init(project=config.train.wandb_project, name=wandb_run_name, config=config)
    wandb.watch(model)

## Training loop: one train_step call per step index.
try:
    for step in range(config.train.num_steps):
        train_step(
            model=model,
            optim=optim,
            data_samplers=data_samplers,
            step=step,
            config=config,
            device=device,
        )
finally:
    # Close the W&B run even when training raises or the cell is interrupted,
    # so a half-finished run is not left open in the kernel.
    if config.train.wandb:
        wandb.finish()






Step 0 -- Train loss: 2.887983560562134, Train Acc: 0.062744140625 Test Acc: 0.0576171875
Step 1 -- Train loss: 2.871685266494751, Train Acc: 0.05908203125 Test Acc: 0.0498046875
Step 2 -- Train loss: 2.8500843048095703, Train Acc: 0.06640625 Test Acc: 0.064453125
Step 3 -- Train loss: 2.8298401832580566, Train Acc: 0.0654296875 Test Acc: 0.0537109375
Step 4 -- Train loss: 2.819368839263916, Train Acc: 0.06005859375 Test Acc: 0.0576171875
Step 5 -- Train loss: 2.8032753467559814, Train Acc: 0.058349609375 Test Acc: 0.05859375
Step 6 -- Train loss: 2.7883567810058594, Train Acc: 0.065185546875 Test Acc: 0.06640625
Step 7 -- Train loss: 2.7779343128204346, Train Acc: 0.064208984375 Test Acc: 0.0537109375
Step 8 -- Train loss: 2.774282217025757, Train Acc: 0.064453125 Test Acc: 0.0625
Step 9 -- Train loss: 2.7677390575408936, Train Acc: 0.063232421875 Test Acc: 0.0712890625
Step 10 -- Train loss: 2.762723207473755, Train Acc: 0.065185546875 Test Acc: 0.0634765625
Step 11 -- Train loss: 2.



0,1
att_prog_measure,▅▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▂▂▂▂▃▄▄▅▆▇▇█
data_repeat_frac,▆▃▄▆▄▄▂▄▆▄▃▅▃▅▂▄▄▁▆▅▅▅▄▅▃▄▇▄▄▄▅▅▇█▄▅▄▃▅▃
idx0_check,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▂▂▂▂▂▃▂▃▄▄▅▆▇▇▇▇███
idx10_check,▄▃▁▅▄▃▆▆▅▄▇▃▆▅▄█▆▃▅▃▅▁▇█▆▄▄▄▄▅▃▃▅▄▃▆▄▅▆▃
idx11_check,▄▆▃▄▆▃▅▂▄▄▄▃▃▃▂▃▃▄▃▃▅▅▃▅▄▄█▅▃▁▁▅▃▃▄▂▅▂▃▄
idx12_check,▃▄▆▃▆▄█▆▅▂▃▂▄▁▃▂▂▄▄▇▆▆▅▅▆▅▄▃▆▄▃▅▇▄▄█▄▇▄▄
idx13_check,▄▇▄▆▄▆▄▅▇▅▅▂▇▃▄▅▅▇▇▇▂▄▅▅▁▂▅▂▅▄█▆▄▄▇▃▅▇▇▅
idx14_check,▃▅▅▅▇▆▆▃▃▇▅▃▅▄▁▃▄▅▅▃▃▅▁▃▅▅▅▇▁▃▆▅▅▆▅▃█▄▃▃
idx15_check,█▅▆▄▃▄▆▆▆▆▅▄▇▂▇▄▄▄▇▅▇▇▅▄▅▂█▂▃▃▂▆▅▃▆▆▃▁▄▆
idx1_check,▆▂▃▃▁▆▄▆▅▃▅▇▁▆▅▂▃▃▆▄▄█▆▃▁▃▄▃▆▂▂▅▅▅▅▂██▃▃

0,1
att_prog_measure,0.08647
data_repeat_frac,0.05208
idx0_check,1.0
idx10_check,0.04688
idx11_check,0.05859
idx12_check,0.05469
idx13_check,0.05469
idx14_check,0.04688
idx15_check,0.07812
idx1_check,0.05469
