# Imports

In [1]:
import sys

import torch
import wandb

sys.path.append('../.')

from torch.utils.data import DataLoader,Dataset
from torch import nn, optim
from torch.functional import F
from myutils.template import train_model
from myutils.python_and_os import ignore_warnings

ignore_warnings()

# Config

In [2]:
config = {
    'project': 'LSTM',
    'lr': 1e-3,
    'epoch': 30,
    'batch_size':1,
    'dataloader_shuffle': True,
    'momentum': 0.9,
    'device': 'cuda',
    'save': True,
    'save_dir': "epoch_{}_{:.3f}.pth",
    'LSTM_window':20
}

# Model define

In [3]:
class MyLSTM(nn.Module):

    def __init__(self,input_dem:int,hs_dem:int):

        super(MyLSTM,self).__init__()

        self.hs_dem = hs_dem

        # none grad parameter
        self.cell_state = torch.zeros(hs_dem)
        self.hidden_state = torch.zeros(hs_dem)

        # require grad parameter
        self.w_xi = nn.Parameter(torch.randn(input_dem,hs_dem),requires_grad=True)
        self.w_xf = nn.Parameter(torch.randn(input_dem,hs_dem),requires_grad=True)
        self.w_xo = nn.Parameter(torch.randn(input_dem,hs_dem),requires_grad=True)
        self.w_xc = nn.Parameter(torch.randn(input_dem,hs_dem),requires_grad=True)

        self.w_hi = nn.Parameter(torch.randn(hs_dem,hs_dem),requires_grad=True)
        self.w_hf = nn.Parameter(torch.randn(hs_dem,hs_dem),requires_grad=True)
        self.w_ho = nn.Parameter(torch.randn(hs_dem,hs_dem),requires_grad=True)
        self.w_hc = nn.Parameter(torch.randn(hs_dem,hs_dem),requires_grad=True)

        self.b_i = nn.Parameter(torch.zeros(hs_dem),requires_grad=True)
        self.b_f = nn.Parameter(torch.zeros(hs_dem),requires_grad=True)
        self.b_o = nn.Parameter(torch.zeros(hs_dem),requires_grad=True)
        self.b_c = nn.Parameter(torch.zeros(hs_dem),requires_grad=True)

    def _forward(self,x):
        i = F.sigmoid(torch.matmul(x,self.w_xi) + torch.matmul(self.hidden_state,self.w_hi) + self.b_i)
        f = F.sigmoid(torch.matmul(x,self.w_xf) + torch.matmul(self.hidden_state,self.w_hf) + self.b_f)
        o = F.sigmoid(torch.matmul(x,self.w_xo) + torch.matmul(self.hidden_state,self.w_ho) + self.b_o)
        c = F.tanh(torch.matmul(x,self.w_xc) + torch.matmul(self.hidden_state,self.w_hc) + self.b_c)

        self.cell_state = torch.mul(f,self.cell_state) + torch.dot(i,c)
        self.hidden_state = torch.mul(o,torch.tanh(self.cell_state))

        return self.hidden_state

    def forward(self,x):
        return self._forward(x)

    def clean_states(self):
        self.cell_state = torch.zeros(self.hs_dem)
        self.hidden_state = torch.zeros(self.hs_dem)

class MySinPreModel(nn.Module):
    def __init__(self,input_num:int,input_dem:int,hs_dem:int):
        super(MySinPreModel,self).__init__()
        self.input_num = input_num
        self.lstm = MyLSTM(input_dem=input_dem,hs_dem=hs_dem)
        self.mlp = nn.Sequential(
            nn.Linear(hs_dem,1000),
            nn.ReLU(),
            nn.Linear(1000,100),
            nn.ReLU(),
            nn.Linear(100,1)
        )

    def forward(self,x):
        for index in range(self.input_num):
            self.lstm(x[index])
        return_value = self.mlp(self.lstm.hidden_state)
        self.lstm.clean_states()
        return return_value

# Preprocess

# Dataset define

In [None]:
class MySinDataSet(Dataset):
    def __init__(self,input_size:int,dataset_size:int,start:float = None,end:float = None):
        import numpy as np
        self.dataset_size = dataset_size
        self.input_size = input_size
        self.x = np.linspace(start=start,stop=end,num=dataset_size+input_size)
        self.data = np.sin(self.x)
    def __len__(self):
        return self.dataset_size
    def __getitem__(self, item):
        return torch.asarray(self.data[item:item+self.input_size]),torch.asarray(self.data[item+self.input_size+1])

# DataLoader define

In [5]:
train_dataloader = DataLoader(dataset=MySinDataSet(config['LSTM_window'],100000,start=0,end=1000),batch_size=config['batch_size'],shuffle=True)
test_dataloader = DataLoader(dataset=MySinDataSet(config['LSTM_window'],10000,start=1001,end=1101),batch_size=config['batch_size'],shuffle=True)
valid_dataloader = DataLoader(dataset=MySinDataSet(config['LSTM_window'],10000,start=1102,end=1202),batch_size=config['batch_size'],shuffle=True)

# Model initial

In [6]:
model = MySinPreModel(config['LSTM_window'],1,100)

# Loss Func and Optimizer

## Fine tune param select

In [7]:
# other_param = []
# layer4_param = []
# fc_param = []
# i = 0
#
# for name, param in  model.named_parameters():
#     if "layer4" in name:
#         layer4_param.append(param)
#     elif "fc" in name:
#         fc_param.append(param)
#     else:
#         other_param.append(param)
#     i += 1

# Loss and Optimizer

In [8]:
loss_func = nn.MSELoss()
loss_func = loss_func.to(config['device']) # using cuda
optimizer = optim.SGD(model.parameters(),config['lr'],momentum=0.9)

# Train loop

In [9]:
# # test process define
# def test_model(config:dict,model:nn.Module,test_dataLoader:DataLoader): #TODO: complete the test template code
#     model.eval()
#     with torch.no_grad():
#
#         num_total = 0
#         acc_total = 0
#         for i,(input,target) in enumerate(test_dataLoader):
#             input = input.to(config['device'])
#             output:torch.Tensor = model(input)
#             out_index = F.softmax(output).argmax(dim=1)
#             target = target.to(config['device'])
#             if (out_index == target)[0].item():
#                 acc_total += 1
#             num_total += 1
#
#         wandb.log({
#             'test_acc': acc_total/num_total
#         })
#         return acc_total/num_total

# process after model output and before loss func
def output_process(output:torch.Tensor):
    return output

wandb.init(project=config['project'],config=config)
for epoch in range(config['epoch']):
    epoch_loss = train_model(config=config,model=model,data_loader=train_dataloader,loss_func=loss_func,optimizer=optimizer,epoch_num=epoch,output_process=output_process)


[34m[1mwandb[0m: Currently logged in as: [33mgeraltigas[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0:   0%|          | 0/100000 [00:00<?, ?it/s]


RuntimeError: size mismatch, got 100, 100x1,20