In [96]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from tqdm import tqdm
import os
from datetime import datetime

In [97]:
class CustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [98]:
class Estimatio_LSTM(nn.Module):
    def __init__(self,input_size,hidden_size,output_size,seq_len,num_layers):
        super().__init__()
        self.seq_len = seq_len
        self.num_layers =num_layers
        self.hidden_size = hidden_size
        self.input_size = input_size

        self.lstm = nn.LSTM(input_size, hidden_size,num_layers=num_layers, batch_first = True)
        self.fc = nn.Linear(hidden_size, output_size)

    
    def forward(self,x):
        # x [batch ,times_seires,input_size]
        out, _ = self.lstm(x)
        # out [batch,hidden_szie]
        out = self.fc(out[:,-1])
        return out
    def reset_hidden_state(self):
        self.hidden = (
            torch.zeros(self.num_layers, self.seq_len, self.hidden_size),
            torch.zeros(self.num_layers, self.seq_len, self.hidden_size)
        )

In [100]:
class Estimation_My():
    def __init__(self):

        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
            print('Cuda is available')
        else:
            print('there is no Cuda')
            self.device = torch.device('cpu')

        torch.cuda.set_device(self.device)

    def normalization(self,data):
            mean = torch.mean(data, dim=0)
            std = torch.std(data, dim=0)
            result = 100 * (data-mean) / std
            
            return result
    
    def make_dir(self,log_dir,pt_dir):
        self.pt_dir = os.path.join(pt_dir,str(self.hidden_size))
        self.log_dir = log_dir
        if os.path.isdir(self.log_dir) == False:
            os.makedirs(self.log_dir)
        if os.path.isdir(self.pt_dir) == False:
            os.makedirs(self.pt_dir)

    def data_loader(self,Data_path,batch_size =32, seq_len=50,split = 0.2):
        self.seq_len = seq_len
        Data = pd.read_csv(Data_path)
        y = Data.pop(str(Data.columns[-1])).values


        Data = torch.tensor(Data.values, dtype=torch.float32)
        Data = self.normalization(Data)

        y = torch.tensor(y,dtype=torch.float32)
        y = self.normalization(y)
        data_X = []
        data_Y = []
        for i in range(0, len(Data)-seq_len):
            if i+seq_len > len(Data):
                break
            _x = Data[i:i+seq_len, :]
            _y = y[i:i+seq_len]

            data_X.append(_x)  
            data_Y.append(_y)
        print(len(data_Y))
        data_X = torch.FloatTensor(np.array(data_X))
        data_Y = torch.FloatTensor(np.array(data_Y))

        
        data_X = data_X.to(self.device)
        data_Y = data_Y.to(self.device)

        X_train, X_val = random_split(data_X,[int(len(data_X)*split),len(data_X)-int(len(data_X)*split)])
        y_train, y_val = random_split(data_Y,[int(len(data_X)*split),len(data_X)-int(len(data_X)*split)])
        train_dataset = CustomDataset(X_train,y_train)
        val_dataset = CustomDataset(X_val,y_val)

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size)       

    def check_the_dataset(self):
        print(self.train_loader)
        print(next(iter(self.train_loader)))
    
    
    def model_setting(self,hidden_size,output_size=1,num_layers=1):
        self.hidden_size = hidden_size
        sample_batch = next(iter(self.train_loader))
        self.input_size = sample_batch[0].shape[2]
        print(self.input_size)
        self.model = Estimatio_LSTM(self.input_size,self.hidden_size,output_size,self.seq_len, num_layers)
        self.model.to(self.device)

    def train_setting(self,lr=1e-4, loss = nn.L1Loss()):
        self.criterion = loss
        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(),lr=lr)
        
    
    def save_checkpoint(self, filename):
        torch.save(self.model, filename)

    
    def train(self,epoch):
        self.loss_log=[]
        for epoch in tqdm(range(epoch)):
            total_train_loss = 0
            self.model.train()
            for inputs, labels in self.train_loader:
                self.model.reset_hidden_state()
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                outputs = outputs.reshape((-1,1))
                train_loss = self.criterion(outputs, labels)
                train_loss.backward()
                self.optimizer.step()
                total_train_loss += train_loss.item()
            avg_train_loss =total_train_loss/len(self.train_loader)

            # Validation loop
            self.model.eval()
            with torch.no_grad():
                val_loss =0
                for inputs, labels in self.val_loader:
                    outputs = self.model(inputs)
                    val_loss += self.criterion(outputs, labels).item()
                avg_val_loss = val_loss/ len(self.val_loader)
            
            self.loss_log.append([avg_train_loss,avg_val_loss])

            print(f'Epoch {epoch+1}, Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}')
            self.save_checkpoint(os.path.join(self.pt_dir,f'model_epoch_{epoch}.pt'))

In [101]:
batch_size = 2
lr = 1e-4
hidden_size = 5
depth = 5
epoch = 500
#path
today = datetime.today()
Data_path = os.path.join('../Data/','TV_JM_145120.csv')
log_dir = os.path.join('../Data/ML',str(today.date()),'log') 
pt_dir = os.path.join('../Data/ML',str(today.date()),"pt")

model =  Estimation_My()
model.data_loader(Data_path,batch_size=batch_size)
model.model_setting(hidden_size=10)
model.train_setting()
model.make_dir(pt_dir=pt_dir,log_dir=log_dir)
model.train(epoch=100)

Cuda is available
2983
11


  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)
  1%|          | 1/100 [00:00<01:37,  1.02it/s]

Epoch 1, Loss: 78.6460, Validation Loss: 80.2296


  2%|▏         | 2/100 [00:01<01:28,  1.11it/s]

Epoch 2, Loss: 78.6389, Validation Loss: 80.2285


  3%|▎         | 3/100 [00:02<01:24,  1.15it/s]

Epoch 3, Loss: 78.6338, Validation Loss: 80.2276


  4%|▍         | 4/100 [00:03<01:21,  1.17it/s]

Epoch 4, Loss: 78.6290, Validation Loss: 80.2268


  5%|▌         | 5/100 [00:04<01:20,  1.19it/s]

Epoch 5, Loss: 78.6246, Validation Loss: 80.2261


  6%|▌         | 6/100 [00:05<01:18,  1.20it/s]

Epoch 6, Loss: 78.6218, Validation Loss: 80.2255


  7%|▋         | 7/100 [00:05<01:17,  1.20it/s]

Epoch 7, Loss: 78.6183, Validation Loss: 80.2251


  8%|▊         | 8/100 [00:06<01:16,  1.21it/s]

Epoch 8, Loss: 78.6150, Validation Loss: 80.2244


  9%|▉         | 9/100 [00:07<01:14,  1.21it/s]

Epoch 9, Loss: 78.6121, Validation Loss: 80.2238


 10%|█         | 10/100 [00:08<01:13,  1.22it/s]

Epoch 10, Loss: 78.6086, Validation Loss: 80.2236


 11%|█         | 11/100 [00:09<01:13,  1.22it/s]

Epoch 11, Loss: 78.6049, Validation Loss: 80.2227


 12%|█▏        | 12/100 [00:10<01:13,  1.20it/s]

Epoch 12, Loss: 78.6020, Validation Loss: 80.2215


 13%|█▎        | 13/100 [00:10<01:12,  1.19it/s]

Epoch 13, Loss: 78.5993, Validation Loss: 80.2212


 14%|█▍        | 14/100 [00:11<01:12,  1.18it/s]

Epoch 14, Loss: 78.5960, Validation Loss: 80.2205


 15%|█▌        | 15/100 [00:12<01:12,  1.17it/s]

Epoch 15, Loss: 78.5927, Validation Loss: 80.2196


 16%|█▌        | 16/100 [00:13<01:11,  1.17it/s]

Epoch 16, Loss: 78.5897, Validation Loss: 80.2186


 17%|█▋        | 17/100 [00:14<01:10,  1.18it/s]

Epoch 17, Loss: 78.5865, Validation Loss: 80.2177


 18%|█▊        | 18/100 [00:15<01:08,  1.19it/s]

Epoch 18, Loss: 78.5825, Validation Loss: 80.2167


 19%|█▉        | 19/100 [00:16<01:07,  1.20it/s]

Epoch 19, Loss: 78.5792, Validation Loss: 80.2162


 20%|██        | 20/100 [00:16<01:06,  1.20it/s]

Epoch 20, Loss: 78.5758, Validation Loss: 80.2151


 21%|██        | 21/100 [00:17<01:05,  1.21it/s]

Epoch 21, Loss: 78.5715, Validation Loss: 80.2143


 22%|██▏       | 22/100 [00:18<01:03,  1.22it/s]

Epoch 22, Loss: 78.5670, Validation Loss: 80.2139


 23%|██▎       | 23/100 [00:19<01:02,  1.23it/s]

Epoch 23, Loss: 78.5637, Validation Loss: 80.2133


 24%|██▍       | 24/100 [00:20<01:01,  1.24it/s]

Epoch 24, Loss: 78.5602, Validation Loss: 80.2123


 25%|██▌       | 25/100 [00:20<01:00,  1.24it/s]

Epoch 25, Loss: 78.5566, Validation Loss: 80.2113


 26%|██▌       | 26/100 [00:21<00:59,  1.24it/s]

Epoch 26, Loss: 78.5531, Validation Loss: 80.2114


 27%|██▋       | 27/100 [00:22<00:59,  1.23it/s]

Epoch 27, Loss: 78.5501, Validation Loss: 80.2097


 28%|██▊       | 28/100 [00:23<00:58,  1.22it/s]

Epoch 28, Loss: 78.5464, Validation Loss: 80.2093


 29%|██▉       | 29/100 [00:24<00:58,  1.21it/s]

Epoch 29, Loss: 78.5436, Validation Loss: 80.2082


 30%|███       | 30/100 [00:25<00:58,  1.20it/s]

Epoch 30, Loss: 78.5404, Validation Loss: 80.2081


 31%|███       | 31/100 [00:25<00:57,  1.20it/s]

Epoch 31, Loss: 78.5357, Validation Loss: 80.2082


 32%|███▏      | 32/100 [00:26<00:56,  1.20it/s]

Epoch 32, Loss: 78.5316, Validation Loss: 80.2070


 33%|███▎      | 33/100 [00:27<00:55,  1.20it/s]

Epoch 33, Loss: 78.5293, Validation Loss: 80.2066


 34%|███▍      | 34/100 [00:28<00:55,  1.20it/s]

Epoch 34, Loss: 78.5265, Validation Loss: 80.2060


 35%|███▌      | 35/100 [00:29<00:54,  1.19it/s]

Epoch 35, Loss: 78.5230, Validation Loss: 80.2050


 36%|███▌      | 36/100 [00:30<00:53,  1.20it/s]

Epoch 36, Loss: 78.5186, Validation Loss: 80.2042


 37%|███▋      | 37/100 [00:30<00:52,  1.20it/s]

Epoch 37, Loss: 78.5159, Validation Loss: 80.2038


 38%|███▊      | 38/100 [00:31<00:51,  1.20it/s]

Epoch 38, Loss: 78.5129, Validation Loss: 80.2033


 39%|███▉      | 39/100 [00:32<00:50,  1.21it/s]

Epoch 39, Loss: 78.5095, Validation Loss: 80.2022


 40%|████      | 40/100 [00:33<00:49,  1.22it/s]

Epoch 40, Loss: 78.5064, Validation Loss: 80.2015


 41%|████      | 41/100 [00:34<00:48,  1.22it/s]

Epoch 41, Loss: 78.5030, Validation Loss: 80.2006


 42%|████▏     | 42/100 [00:34<00:47,  1.23it/s]

Epoch 42, Loss: 78.4988, Validation Loss: 80.2002


 43%|████▎     | 43/100 [00:35<00:47,  1.20it/s]

Epoch 43, Loss: 78.4954, Validation Loss: 80.1983


 44%|████▍     | 44/100 [00:36<00:47,  1.18it/s]

Epoch 44, Loss: 78.4931, Validation Loss: 80.1980


 45%|████▌     | 45/100 [00:37<00:46,  1.18it/s]

Epoch 45, Loss: 78.4882, Validation Loss: 80.1974


 46%|████▌     | 46/100 [00:38<00:45,  1.18it/s]

Epoch 46, Loss: 78.4864, Validation Loss: 80.1968


 47%|████▋     | 47/100 [00:39<00:45,  1.17it/s]

Epoch 47, Loss: 78.4814, Validation Loss: 80.1959


 48%|████▊     | 48/100 [00:40<00:43,  1.18it/s]

Epoch 48, Loss: 78.4793, Validation Loss: 80.1947


 49%|████▉     | 49/100 [00:40<00:42,  1.20it/s]

Epoch 49, Loss: 78.4762, Validation Loss: 80.1938


 50%|█████     | 50/100 [00:41<00:42,  1.19it/s]

Epoch 50, Loss: 78.4726, Validation Loss: 80.1931


 51%|█████     | 51/100 [00:42<00:41,  1.19it/s]

Epoch 51, Loss: 78.4690, Validation Loss: 80.1925


 52%|█████▏    | 52/100 [00:43<00:39,  1.21it/s]

Epoch 52, Loss: 78.4641, Validation Loss: 80.1917


 53%|█████▎    | 53/100 [00:44<00:38,  1.21it/s]

Epoch 53, Loss: 78.4615, Validation Loss: 80.1914


 54%|█████▍    | 54/100 [00:45<00:38,  1.20it/s]

Epoch 54, Loss: 78.4579, Validation Loss: 80.1917


 55%|█████▌    | 55/100 [00:45<00:37,  1.21it/s]

Epoch 55, Loss: 78.4551, Validation Loss: 80.1901


 56%|█████▌    | 56/100 [00:46<00:36,  1.21it/s]

Epoch 56, Loss: 78.4498, Validation Loss: 80.1898


 57%|█████▋    | 57/100 [00:47<00:35,  1.21it/s]

Epoch 57, Loss: 78.4454, Validation Loss: 80.1886


 58%|█████▊    | 58/100 [00:48<00:34,  1.20it/s]

Epoch 58, Loss: 78.4401, Validation Loss: 80.1882


 59%|█████▉    | 59/100 [00:49<00:34,  1.20it/s]

Epoch 59, Loss: 78.4369, Validation Loss: 80.1878


 60%|██████    | 60/100 [00:50<00:33,  1.20it/s]

Epoch 60, Loss: 78.4335, Validation Loss: 80.1867


 61%|██████    | 61/100 [00:50<00:32,  1.20it/s]

Epoch 61, Loss: 78.4302, Validation Loss: 80.1857


 62%|██████▏   | 62/100 [00:51<00:31,  1.19it/s]

Epoch 62, Loss: 78.4259, Validation Loss: 80.1857


 63%|██████▎   | 63/100 [00:52<00:30,  1.20it/s]

Epoch 63, Loss: 78.4227, Validation Loss: 80.1856


 64%|██████▍   | 64/100 [00:53<00:29,  1.20it/s]

Epoch 64, Loss: 78.4186, Validation Loss: 80.1842


 65%|██████▌   | 65/100 [00:54<00:29,  1.20it/s]

Epoch 65, Loss: 78.4148, Validation Loss: 80.1840


 66%|██████▌   | 66/100 [00:55<00:28,  1.19it/s]

Epoch 66, Loss: 78.4106, Validation Loss: 80.1842


 67%|██████▋   | 67/100 [00:55<00:27,  1.19it/s]

Epoch 67, Loss: 78.4076, Validation Loss: 80.1830


 68%|██████▊   | 68/100 [00:56<00:27,  1.18it/s]

Epoch 68, Loss: 78.4037, Validation Loss: 80.1833


 69%|██████▉   | 69/100 [00:57<00:26,  1.17it/s]

Epoch 69, Loss: 78.3996, Validation Loss: 80.1811


 70%|███████   | 70/100 [00:58<00:25,  1.16it/s]

Epoch 70, Loss: 78.3971, Validation Loss: 80.1820


 71%|███████   | 71/100 [00:59<00:25,  1.16it/s]

Epoch 71, Loss: 78.3935, Validation Loss: 80.1814


 72%|███████▏  | 72/100 [01:00<00:24,  1.16it/s]

Epoch 72, Loss: 78.3892, Validation Loss: 80.1815


 73%|███████▎  | 73/100 [01:01<00:23,  1.17it/s]

Epoch 73, Loss: 78.3866, Validation Loss: 80.1807


 74%|███████▍  | 74/100 [01:01<00:21,  1.19it/s]

Epoch 74, Loss: 78.3824, Validation Loss: 80.1811


 75%|███████▌  | 75/100 [01:02<00:20,  1.22it/s]

Epoch 75, Loss: 78.3795, Validation Loss: 80.1805


 76%|███████▌  | 76/100 [01:03<00:19,  1.22it/s]

Epoch 76, Loss: 78.3757, Validation Loss: 80.1788


 77%|███████▋  | 77/100 [01:04<00:18,  1.24it/s]

Epoch 77, Loss: 78.3738, Validation Loss: 80.1786


 78%|███████▊  | 78/100 [01:05<00:17,  1.25it/s]

Epoch 78, Loss: 78.3689, Validation Loss: 80.1792


 79%|███████▉  | 79/100 [01:05<00:16,  1.25it/s]

Epoch 79, Loss: 78.3657, Validation Loss: 80.1792


 80%|████████  | 80/100 [01:06<00:15,  1.25it/s]

Epoch 80, Loss: 78.3633, Validation Loss: 80.1788


 81%|████████  | 81/100 [01:07<00:15,  1.26it/s]

Epoch 81, Loss: 78.3599, Validation Loss: 80.1780


 82%|████████▏ | 82/100 [01:08<00:14,  1.25it/s]

Epoch 82, Loss: 78.3559, Validation Loss: 80.1788


 83%|████████▎ | 83/100 [01:09<00:13,  1.25it/s]

Epoch 83, Loss: 78.3523, Validation Loss: 80.1782


 84%|████████▍ | 84/100 [01:09<00:12,  1.24it/s]

Epoch 84, Loss: 78.3490, Validation Loss: 80.1777


 85%|████████▌ | 85/100 [01:10<00:12,  1.24it/s]

Epoch 85, Loss: 78.3457, Validation Loss: 80.1768


 86%|████████▌ | 86/100 [01:11<00:11,  1.24it/s]

Epoch 86, Loss: 78.3427, Validation Loss: 80.1775


 87%|████████▋ | 87/100 [01:12<00:10,  1.24it/s]

Epoch 87, Loss: 78.3387, Validation Loss: 80.1760


 88%|████████▊ | 88/100 [01:13<00:09,  1.24it/s]

Epoch 88, Loss: 78.3347, Validation Loss: 80.1770


 89%|████████▉ | 89/100 [01:13<00:08,  1.24it/s]

Epoch 89, Loss: 78.3315, Validation Loss: 80.1745


 90%|█████████ | 90/100 [01:14<00:08,  1.25it/s]

Epoch 90, Loss: 78.3279, Validation Loss: 80.1739


 91%|█████████ | 91/100 [01:15<00:07,  1.24it/s]

Epoch 91, Loss: 78.3255, Validation Loss: 80.1733


 92%|█████████▏| 92/100 [01:16<00:06,  1.24it/s]

Epoch 92, Loss: 78.3219, Validation Loss: 80.1750


 93%|█████████▎| 93/100 [01:17<00:05,  1.24it/s]

Epoch 93, Loss: 78.3191, Validation Loss: 80.1737


 94%|█████████▍| 94/100 [01:17<00:04,  1.23it/s]

Epoch 94, Loss: 78.3167, Validation Loss: 80.1735


 95%|█████████▌| 95/100 [01:18<00:04,  1.22it/s]

Epoch 95, Loss: 78.3138, Validation Loss: 80.1716


 96%|█████████▌| 96/100 [01:19<00:03,  1.20it/s]

Epoch 96, Loss: 78.3120, Validation Loss: 80.1727


 97%|█████████▋| 97/100 [01:20<00:02,  1.19it/s]

Epoch 97, Loss: 78.3059, Validation Loss: 80.1726


 98%|█████████▊| 98/100 [01:21<00:01,  1.20it/s]

Epoch 98, Loss: 78.3025, Validation Loss: 80.1739


 99%|█████████▉| 99/100 [01:22<00:00,  1.20it/s]

Epoch 99, Loss: 78.3005, Validation Loss: 80.1726


100%|██████████| 100/100 [01:22<00:00,  1.21it/s]

Epoch 100, Loss: 78.2964, Validation Loss: 80.1717



