<a href="https://colab.research.google.com/github/cmarschner/scratch/blob/master/timeseries.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from pathlib import Path
import requests
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn
from torch import optim
from matplotlib import pyplot
import numpy as np
from typing import List
from sklearn.preprocessing import MinMaxScaler

In [0]:
DATA_PATH = Path("data")
PATH = DATA_PATH / "beijing-air-quality"

PATH.mkdir(parents=True, exist_ok=True)
# https://archive.ics.uci.edu/ml/datasets/Beijing+Multi-Site+Air-Quality+Data/
URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00501/"
FILENAME = "PRSA2017_Data_20130301-20170228.zip"

if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).open("wb").write(content)

In [0]:
!unzip -o {PATH}/{FILENAME} -d {PATH} 

Archive:  data/beijing-air-quality/PRSA2017_Data_20130301-20170228.zip
   creating: data/beijing-air-quality/PRSA_Data_20130301-20170228/
  inflating: data/beijing-air-quality/PRSA_Data_20130301-20170228/PRSA_Data_Aotizhongxin_20130301-20170228.csv  
  inflating: data/beijing-air-quality/PRSA_Data_20130301-20170228/PRSA_Data_Changping_20130301-20170228.csv  
  inflating: data/beijing-air-quality/PRSA_Data_20130301-20170228/PRSA_Data_Dingling_20130301-20170228.csv  
  inflating: data/beijing-air-quality/PRSA_Data_20130301-20170228/PRSA_Data_Dongsi_20130301-20170228.csv  
  inflating: data/beijing-air-quality/PRSA_Data_20130301-20170228/PRSA_Data_Guanyuan_20130301-20170228.csv  
  inflating: data/beijing-air-quality/PRSA_Data_20130301-20170228/PRSA_Data_Gucheng_20130301-20170228.csv  
  inflating: data/beijing-air-quality/PRSA_Data_20130301-20170228/PRSA_Data_Huairou_20130301-20170228.csv  
  inflating: data/beijing-air-quality/PRSA_Data_20130301-20170228/PRSA_Data_Nongzhanguan_20130301-

In [0]:
SUBDIR = PATH / "PRSA_Data_20130301-20170228"
files = ["PRSA_Data_Changping_20130301-20170228.csv", "PRSA_Data_Aotizhongxin_20130301-20170228.csv"]
import gzip
import pandas as pd

df = pd.read_csv(SUBDIR / files[0])

In [0]:
len(df)

35064

In [0]:
df

Unnamed: 0,No,year,month,day,hour,PM2.5,PM10,SO2,NO2,CO,O3,TEMP,PRES,DEWP,RAIN,wd,WSPM,station
0,1,2013,3,1,0,3.0,6.0,13.0,7.0,300.0,85.0,-2.3,1020.8,-19.7,0.0,E,0.5,Changping
1,2,2013,3,1,1,3.0,3.0,6.0,6.0,300.0,85.0,-2.5,1021.3,-19.0,0.0,ENE,0.7,Changping
2,3,2013,3,1,2,3.0,3.0,22.0,13.0,400.0,74.0,-3.0,1021.3,-19.9,0.0,ENE,0.2,Changping
3,4,2013,3,1,3,3.0,6.0,12.0,8.0,300.0,81.0,-3.6,1021.8,-19.1,0.0,NNE,1.0,Changping
4,5,2013,3,1,4,3.0,3.0,14.0,8.0,300.0,81.0,-3.5,1022.3,-19.4,0.0,N,2.1,Changping
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
35059,35060,2017,2,28,19,28.0,47.0,4.0,14.0,300.0,,11.7,1008.9,-13.3,0.0,NNE,1.3,Changping
35060,35061,2017,2,28,20,12.0,12.0,3.0,23.0,500.0,64.0,10.9,1009.0,-14.0,0.0,N,2.1,Changping
35061,35062,2017,2,28,21,7.0,23.0,5.0,17.0,500.0,68.0,9.5,1009.4,-13.0,0.0,N,1.5,Changping
35062,35063,2017,2,28,22,11.0,20.0,3.0,15.0,500.0,72.0,7.8,1009.6,-12.6,0.0,NW,1.4,Changping


In [0]:
dd = df[0:3][["PM2.5","PM10", "NO2", "CO", "O3", "TEMP", "PRES", "DEWP", "RAIN", "WSPM"]].to_numpy()
dd

array([[ 3.0000e+00,  6.0000e+00,  7.0000e+00,  3.0000e+02,  8.5000e+01,
        -2.3000e+00,  1.0208e+03, -1.9700e+01,  0.0000e+00,  5.0000e-01],
       [ 3.0000e+00,  3.0000e+00,  6.0000e+00,  3.0000e+02,  8.5000e+01,
        -2.5000e+00,  1.0213e+03, -1.9000e+01,  0.0000e+00,  7.0000e-01],
       [ 3.0000e+00,  3.0000e+00,  1.3000e+01,  4.0000e+02,  7.4000e+01,
        -3.0000e+00,  1.0213e+03, -1.9900e+01,  0.0000e+00,  2.0000e-01]])

In [0]:
dev = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

In [0]:
dev

device(type='cuda')

In [0]:
from IPython.core.debugger import set_trace

mms = MinMaxScaler(feature_range=(-1,1))
mmsy = MinMaxScaler(feature_range=(-1,1))

class SequenceDataset(torch.utils.data.Dataset):
    """Dataset for time series prediction - cut out n elements and predict n+1.

    Arguments:
        filenames: csvs to read from
        x_cols: columns to use from the csv for training
        y_cols: columns for prediction
        seq_len: number of elements per sequence
        
    """
    def __init__(self, filenames: List[str], x_cols: List[str], y_cols: List[str], seq_len: int, is_test=False):
        print("reading ds...")
        self.df = pd.concat(pd.read_csv(filename) for filename in filenames)
        print("done...")
        self.df = self.df[list(set(x_cols) | set(y_cols))]  # only columns we are interested in
        all_lines = len(self.df)
        # Drop NaNs. Not ideal but right now the easiest way to get something trained.
        self.df = self.df[~np.isnan(self.df).any(axis=1)]
        print("Keeping only %d out of %d lines due to NaN entries" % (len(self.df), all_lines))
        self.x_cols = x_cols
        self.y_cols = y_cols
        self.seq_len = seq_len
        
        if not is_test:
            mms.fit(self.df[self.x_cols])
            mmsy.fit(self.df[self.y_cols])
        
    def __getitem__(self, index):
        # According to https://pytorch.org/docs/stable/data.html, don't put it on GPU yet.
        x = self.df[index:index + self.seq_len][self.x_cols].to_numpy(dtype=np.float32)
        y = self.df[index + self.seq_len:index + self.seq_len + 1][self.y_cols].to_numpy(dtype=np.float32)
        x = mms.transform(x)
        y = mmsy.transform(y)
        return (torch.tensor(x), torch.tensor(y))

    def __len__(self):
        return len(self.df) - self.seq_len - 1


In [0]:
full_ds = SequenceDataset([SUBDIR / x for x in files], 
                          x_cols=["PM2.5","PM10", "NO2", "CO", "O3", "TEMP", "PRES", "DEWP", "RAIN", "WSPM"], 
                          y_cols=["TEMP"],
                          seq_len=5)

reading ds...
done...
Keeping only 64886 out of 70128 lines due to NaN entries


In [0]:
train_ds, valid_ds = torch.utils.data.random_split(full_ds, [round(0.8 * len(full_ds)), round(0.2 * len(full_ds))])

In [0]:
print(len(train_ds), len(valid_ds))

51904 12976


In [0]:
def get_data(train_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )

In [0]:
def preprocess(x, y):
    # print("to cuda")
    xp, yp = x.to(dev), y.to(dev)
    return xp, yp

In [0]:
class SeqModel(nn.Module):
    """"""
    def __init__(self, input_size, lstm_cell_size, linear_size, output_size):
        super(SeqModel, self).__init__()
        self.lstm = nn.LSTM(input_size, lstm_cell_size, batch_first=True)
        self.lin = nn.Linear(lstm_cell_size, linear_size)
        self.lin2 = nn.Linear(linear_size, output_size)
        
    def forward(self, x):
        lstm_out, _ = self.lstm(x.clamp(-1, 1))
        last_lstm = lstm_out[:, -1, :]
        lin1_out = self.lin(last_lstm) # .sigmoid()
        line2_out = self.lin2(lin1_out)
        return line2_out

In [0]:
model_name = "airq-lstm"

In [0]:
bs = 64  # batch size
epochs = 10  # how many epochs to train for

model = SeqModel(input_size = len(full_ds.x_cols), lstm_cell_size=50, linear_size=50, output_size=1)
model.to(dev)
loss_func = nn.MSELoss(reduction='sum')
opt = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

In [0]:
def loss_batch(model, loss_func, xb, yb, opt=None):
    fwd = model(xb)
    loss = loss_func(fwd.squeeze(), yb.squeeze())
    if torch.isnan(loss).any():
        print("NAN!")
        print(xb)
        raise RuntimeError()
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

In [0]:
from tqdm import tqdm

def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in tqdm(range(epochs), desc="epoch",  total=epochs):
        model.train()
        train_losses = []
        lengths = []
        for i, (xb, yb) in tqdm(enumerate(train_dl), desc="train", total=len(train_dl), mininterval=5, miniters=100):
            assert type(xb) == torch.Tensor and type(yb) == torch.Tensor
            loss, lens = loss_batch(model, loss_func, xb, yb, opt)
            train_losses.append(loss)
            lengths.append(lens)
        train_loss = np.sum(np.multiply(train_losses, lengths)) / np.sum(lengths)
        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in tqdm(valid_dl, desc="valid", total=len(valid_dl), mininterval=5, miniters=100)]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print("epoch %d  train_loss %f val_loss %f" % (epoch, train_loss, val_loss))
        
        torch.save(model.state_dict(), "model.%s.%d.pl" % (model_name, epoch))

In [0]:
class WrappedDataLoader:
    """Applies func to every batch coming from the data loader."""
    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        batches = iter(self.dl)
        for b in batches:
            ret = (self.func(*b))
            yield ret

In [0]:
(train_dl, valid_dl) = get_data(train_ds, bs)

In [0]:
train_dl = WrappedDataLoader(train_dl, preprocess)
valid_dl = WrappedDataLoader(valid_dl, preprocess)

In [0]:
fit(epochs, model, loss_func, opt, train_dl, valid_dl)

epoch:   0%|          | 0/10 [00:00<?, ?it/s]
train:   0%|          | 0/811 [00:00<?, ?it/s][A
train:  12%|█▏        | 100/811 [00:10<01:16,  9.30it/s][A
train:  25%|██▍       | 200/811 [00:21<01:05,  9.31it/s][A
train:  37%|███▋      | 300/811 [00:32<00:54,  9.33it/s][A
train:  49%|████▉     | 400/811 [00:42<00:44,  9.34it/s][A
train:  62%|██████▏   | 500/811 [00:53<00:33,  9.37it/s][A
train:  74%|███████▍  | 600/811 [01:04<00:22,  9.37it/s][A
train:  86%|████████▋ | 700/811 [01:14<00:11,  9.36it/s][A
train:  99%|█████████▊| 800/811 [01:25<00:01,  9.39it/s][A
train: 100%|██████████| 811/811 [01:26<00:00,  9.37it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:13<?, ?it/s][A
valid:  65%|██████▍   | 66/102 [00:13<00:07,  4.84it/s][A
valid:  88%|████████▊ | 90/102 [00:18<00:02,  4.82it/s][A
epoch:  10%|█         | 1/10 [01:47<16:08, 107.65s/it]
train:   0%|          | 0/811 [00:00<?, ?it/s][A

epoch 0  train_loss 1.420674 val_loss 0.961226



train:  12%|█▏        | 100/811 [00:10<01:15,  9.42it/s][A
train:  25%|██▍       | 200/811 [00:21<01:04,  9.41it/s][A
train:  37%|███▋      | 300/811 [00:31<00:54,  9.39it/s][A
train:  37%|███▋      | 300/811 [00:42<00:54,  9.39it/s][A
train:  49%|████▉     | 398/811 [00:42<00:44,  9.37it/s][A
train:  55%|█████▍    | 446/811 [00:47<00:38,  9.40it/s][A
train:  61%|██████    | 494/811 [00:52<00:33,  9.41it/s][A
train:  67%|██████▋   | 540/811 [00:57<00:29,  9.33it/s][A
train:  73%|███████▎  | 588/811 [01:02<00:23,  9.39it/s][A
train:  78%|███████▊  | 636/811 [01:07<00:18,  9.41it/s][A
train:  84%|████████▍ | 683/811 [01:12<00:13,  9.41it/s][A
train:  90%|█████████ | 731/811 [01:17<00:08,  9.45it/s][A
train:  96%|█████████▌| 779/811 [01:22<00:03,  9.46it/s][A
train: 100%|██████████| 811/811 [01:26<00:00,  9.39it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:16<?, ?it/s][A
valid:  77%|███████▋  | 79/102 [00:16<00:04,  4.85it/s][A


epoch 1  train_loss 0.361901 val_loss 0.574774



train:  12%|█▏        | 100/811 [00:10<01:15,  9.41it/s][A
train:  25%|██▍       | 200/811 [00:21<01:04,  9.42it/s][A
train:  37%|███▋      | 300/811 [00:31<00:54,  9.45it/s][A
train:  49%|████▉     | 400/811 [00:42<00:43,  9.44it/s][A
train:  62%|██████▏   | 500/811 [00:53<00:33,  9.41it/s][A
train:  74%|███████▍  | 600/811 [01:03<00:22,  9.44it/s][A
train:  86%|████████▋ | 700/811 [01:14<00:11,  9.44it/s][A
train:  99%|█████████▊| 800/811 [01:24<00:01,  9.47it/s][A
train: 100%|██████████| 811/811 [01:25<00:00,  9.45it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:19<?, ?it/s][A
valid:  92%|█████████▏| 94/102 [00:19<00:01,  4.82it/s][A
epoch:  30%|███       | 3/10 [05:21<12:31, 107.32s/it]
train:   0%|          | 0/811 [00:00<?, ?it/s][A

epoch 2  train_loss 0.258370 val_loss 0.458386



train:  12%|█▏        | 100/811 [00:10<01:14,  9.52it/s][A
train:  25%|██▍       | 200/811 [00:21<01:04,  9.50it/s][A
train:  37%|███▋      | 300/811 [00:31<00:53,  9.51it/s][A
train:  49%|████▉     | 400/811 [00:42<00:43,  9.50it/s][A
train:  62%|██████▏   | 500/811 [00:52<00:32,  9.50it/s][A
train:  74%|███████▍  | 600/811 [01:03<00:22,  9.49it/s][A
train:  86%|████████▋ | 700/811 [01:13<00:11,  9.46it/s][A
train:  99%|█████████▊| 800/811 [01:24<00:01,  9.46it/s][A
train: 100%|██████████| 811/811 [01:25<00:00,  9.48it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:12<?, ?it/s][A
valid:  61%|██████    | 62/102 [00:12<00:08,  4.85it/s][A
valid:  85%|████████▌ | 87/102 [00:17<00:03,  4.85it/s][A
epoch:  40%|████      | 4/10 [07:08<10:42, 107.08s/it]
train:   0%|          | 0/811 [00:00<?, ?it/s][A

epoch 3  train_loss 0.216282 val_loss 0.394293



train:  12%|█▏        | 100/811 [00:10<01:15,  9.44it/s][A
train:  25%|██▍       | 200/811 [00:21<01:04,  9.44it/s][A
train:  37%|███▋      | 300/811 [00:31<00:54,  9.45it/s][A
train:  37%|███▋      | 300/811 [00:41<00:54,  9.45it/s][A
train:  49%|████▊     | 395/811 [00:41<00:44,  9.42it/s][A
train:  55%|█████▍    | 443/811 [00:46<00:38,  9.44it/s][A
train:  60%|██████    | 490/811 [00:51<00:34,  9.43it/s][A
train:  66%|██████▋   | 538/811 [00:56<00:28,  9.46it/s][A
train:  72%|███████▏  | 586/811 [01:02<00:23,  9.46it/s][A
train:  78%|███████▊  | 633/811 [01:07<00:18,  9.42it/s][A
train:  84%|████████▍ | 681/811 [01:12<00:13,  9.45it/s][A
train:  90%|████████▉ | 728/811 [01:17<00:08,  9.42it/s][A
train:  96%|█████████▌| 776/811 [01:22<00:03,  9.43it/s][A
train: 100%|██████████| 811/811 [01:25<00:00,  9.44it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:15<?, ?it/s][A
valid:  75%|███████▌  | 77/102 [00:15<00:05,  4.83it/s][A


epoch 4  train_loss 0.189269 val_loss 0.363356



train:  12%|█▏        | 100/811 [00:10<01:16,  9.32it/s][A
train:  25%|██▍       | 200/811 [00:21<01:05,  9.30it/s][A
train:  37%|███▋      | 300/811 [00:32<00:54,  9.31it/s][A
train:  49%|████▉     | 400/811 [00:43<00:44,  9.27it/s][A
train:  62%|██████▏   | 500/811 [00:53<00:33,  9.31it/s][A
train:  74%|███████▍  | 600/811 [01:04<00:22,  9.33it/s][A
train:  74%|███████▍  | 600/811 [01:14<00:22,  9.33it/s][A
train:  86%|████████▌ | 699/811 [01:15<00:12,  9.33it/s][A
train:  92%|█████████▏| 746/811 [01:20<00:06,  9.30it/s][A
train:  98%|█████████▊| 793/811 [01:25<00:01,  9.32it/s][A
train: 100%|██████████| 811/811 [01:27<00:00,  9.31it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:17<?, ?it/s][A
valid:  84%|████████▍ | 86/102 [00:17<00:03,  4.79it/s][A
epoch:  60%|██████    | 6/10 [10:43<07:09, 107.42s/it]
train:   0%|          | 0/811 [00:00<?, ?it/s][A

epoch 5  train_loss 0.169074 val_loss 0.316692



train:  12%|█▏        | 100/811 [00:10<01:15,  9.42it/s][A
train:  25%|██▍       | 200/811 [00:21<01:04,  9.44it/s][A
train:  37%|███▋      | 300/811 [00:31<00:54,  9.41it/s][A
train:  49%|████▉     | 400/811 [00:42<00:43,  9.43it/s][A
train:  62%|██████▏   | 500/811 [00:52<00:32,  9.44it/s][A
train:  74%|███████▍  | 600/811 [01:03<00:22,  9.39it/s][A
train:  86%|████████▋ | 700/811 [01:14<00:11,  9.40it/s][A
train:  99%|█████████▊| 800/811 [01:24<00:01,  9.42it/s][A
train: 100%|██████████| 811/811 [01:26<00:00,  9.42it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:10<?, ?it/s][A
valid:  50%|█████     | 51/102 [00:10<00:10,  4.78it/s][A
valid:  75%|███████▍  | 76/102 [00:15<00:05,  4.79it/s][A
valid:  99%|█████████▉| 101/102 [00:21<00:00,  4.80it/s][A
epoch:  70%|███████   | 7/10 [12:30<05:22, 107.37s/it]
train:   0%|          | 0/811 [00:00<?, ?it/s][A

epoch 6  train_loss 0.154187 val_loss 0.290622



train:  12%|█▏        | 100/811 [00:10<01:15,  9.48it/s][A
train:  25%|██▍       | 200/811 [00:21<01:04,  9.43it/s][A
train:  37%|███▋      | 300/811 [00:31<00:54,  9.43it/s][A
train:  49%|████▉     | 400/811 [00:42<00:43,  9.39it/s][A
train:  62%|██████▏   | 500/811 [00:53<00:33,  9.40it/s][A
train:  74%|███████▍  | 600/811 [01:03<00:22,  9.43it/s][A
train:  86%|████████▋ | 700/811 [01:14<00:11,  9.43it/s][A
train:  99%|█████████▊| 800/811 [01:24<00:01,  9.46it/s][A
train: 100%|██████████| 811/811 [01:26<00:00,  9.43it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:13<?, ?it/s][A
valid:  65%|██████▍   | 66/102 [00:13<00:07,  4.87it/s][A
valid:  88%|████████▊ | 90/102 [00:18<00:02,  4.84it/s][A
epoch:  80%|████████  | 8/10 [14:17<03:34, 107.26s/it]
train:   0%|          | 0/811 [00:00<?, ?it/s][A

epoch 7  train_loss 0.141752 val_loss 0.269654



train:  12%|█▏        | 100/811 [00:10<01:16,  9.30it/s][A
train:  25%|██▍       | 200/811 [00:21<01:05,  9.36it/s][A
train:  37%|███▋      | 300/811 [00:32<00:54,  9.35it/s][A
train:  49%|████▉     | 400/811 [00:42<00:43,  9.41it/s][A
train:  49%|████▉     | 400/811 [00:52<00:43,  9.41it/s][A
train:  61%|██████▏   | 497/811 [00:52<00:33,  9.45it/s][A
train:  67%|██████▋   | 544/811 [00:57<00:28,  9.43it/s][A
train:  73%|███████▎  | 592/811 [01:02<00:23,  9.46it/s][A
train:  79%|███████▉  | 640/811 [01:07<00:18,  9.48it/s][A
train:  85%|████████▍ | 688/811 [01:12<00:12,  9.48it/s][A
train:  91%|█████████ | 736/811 [01:17<00:07,  9.51it/s][A
train:  97%|█████████▋| 784/811 [01:22<00:02,  9.52it/s][A
train: 100%|██████████| 811/811 [01:25<00:00,  9.46it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:16<?, ?it/s][A
valid:  80%|████████  | 82/102 [00:16<00:04,  4.84it/s][A
epoch:  90%|█████████ | 9/10 [16:04<01:47, 107.06s/it]
train

epoch 8  train_loss 0.131986 val_loss 0.259571



train:  12%|█▏        | 100/811 [00:10<01:15,  9.46it/s][A
train:  25%|██▍       | 200/811 [00:21<01:04,  9.46it/s][A
train:  37%|███▋      | 300/811 [00:31<00:53,  9.48it/s][A
train:  49%|████▉     | 400/811 [00:42<00:43,  9.44it/s][A
train:  62%|██████▏   | 500/811 [00:52<00:32,  9.44it/s][A
train:  74%|███████▍  | 600/811 [01:03<00:22,  9.44it/s][A
train:  86%|████████▋ | 700/811 [01:14<00:11,  9.40it/s][A
train:  99%|█████████▊| 800/811 [01:24<00:01,  9.41it/s][A
train: 100%|██████████| 811/811 [01:26<00:00,  9.43it/s][A
valid:   0%|          | 0/102 [00:00<?, ?it/s][A
valid:   0%|          | 0/102 [00:19<?, ?it/s][A
valid:  95%|█████████▌| 97/102 [00:20<00:01,  4.84it/s][A
epoch: 100%|██████████| 10/10 [17:51<00:00, 107.05s/it]

epoch 9  train_loss 0.124371 val_loss 0.240409



