# Imports

In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import plotly.graph_objs as go
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from utils import TrainTestSequenceDataset, PredictSequenceDataset, smape, process_for_train, features, targets, process_for_predict
from models import LSTM, ResNet18
from tqdm.notebook import tqdm
from IPython.display import clear_output, display
import ipywidgets as widgets
import os
from spaceopt import SpaceOpt

# Loading data

In [None]:
# Load the raw training table; the 'epoch' column is parsed into datetimes.
data = pd.read_csv('train.csv', parse_dates=['epoch'])
# Presumably the ids of the satellites to evaluate (it is iterated in the
# evaluation loop further down) -- confirm against whatever produced the file.
# NOTE(review): torch.load unpickles arbitrary objects; only load this file
# from a trusted source.
test_sat_id = torch.load('test_sat_id')

# Data processing

In [None]:
# Build the train and test collections from the raw table; both are indexed
# per satellite by the cells below (e.g. sat_datas_train[n]).
# NOTE(review): process_for_train_test is not among the names imported from
# utils in the imports cell (only process_for_train / process_for_predict are),
# so this cell raises NameError on a fresh kernel -- confirm the helper exists
# in utils and add it to the import.
sat_datas_train, sat_datas_test = process_for_train_test(data)

In [None]:
# Sanity check: show the first rows of the training data stored under index/key 0.
sat_datas_train[0].head()

# Model

In [None]:
# Hyperparameters and training objects for the single-satellite experiment below.
# seq_len is handed to the LSTM and read back as model.seq_len when the
# datasets are built.
seq_len = 20
hidden_dim = 100
num_layers = 2
model = LSTM(hidden_dim=hidden_dim, seq_len=seq_len, num_layers=num_layers)
model.train()
# NOTE(review): eps=1e-2 is far larger than Adam's default (1e-8) and damps the
# adaptive step size considerably -- presumably deliberate; confirm.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.02, eps=1e-2)
# Loss function imported from utils.
criterion = smape
# Index/key of the satellite used by the train and test cells that follow.
n = 10

# Train on a single satellite (index n)

In [None]:
# Fit the global `model` on the training split of satellite `n`.
model.train()
loss_widget = widgets.FloatProgress(
    min=0, max=1, step=0.01, description='Loss', value=0,
)  # jupyter widget showing the most recent batch loss
display(loss_widget)

train_data = sat_datas_train[n]
x_train, y_train = train_data[features], train_data[targets]
train_dataset = TrainTestSequenceDataset(x_train, y_train, seq_len=model.seq_len)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

for epoch in tqdm(range(15)):
    for seq_train_x, seq_train_y in train_dataloader:
        # Clear gradients left over from the previous step.
        model.zero_grad()
        # Presumably resets the recurrent state before each sequence -- see models.LSTM.
        model.init_hidden_cell()
        predictions = model(seq_train_x)
        loss = criterion(predictions, seq_train_y)
        # Reduce once and reuse the scalar for both the display and backprop.
        mean_loss = loss.mean()
        loss_widget.value = mean_loss
        mean_loss.backward()
        optimizer.step()

In [None]:
# Score the global `model` on the test split of satellite `n`.
# The score is 1 minus the mean per-sequence loss.
model.eval()
score_widget = widgets.FloatProgress(
    min=0, max=1, step=0.01, description='Score', value=0,
)  # jupyter widget showing the running score
display(score_widget)

test_data = sat_datas_test[n]
x_test, y_test = test_data[features], test_data[targets]
test_dataset = TrainTestSequenceDataset(x_test, y_test, seq_len=model.seq_len)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

loss_sum = 0
i = 0
for seq_test_x, seq_test_y in tqdm(test_dataloader):
    # Only the forward pass needs to run without autograd bookkeeping.
    with torch.no_grad():
        model.init_hidden_cell()
        predictions = model(seq_test_x)
        loss = criterion(predictions, seq_test_y).mean()
    loss_sum += loss
    i += 1
    score_widget.value = 1 - loss_sum / i

score = (1 - loss_sum / i).item()
print(score)

In [None]:
# Candidate values per hyperparameter; the keys match those read from `spoint`
# in evaluate_new.
# NOTE(review): nothing visible in this notebook consumes search_space --
# presumably intended for the imported SpaceOpt; confirm or remove.
search_space = {
    'lr': [0.001, 0.01, 0.1, 1.],
    'eps': [1e-8, 1e-4, 1e-2, 1e0],
    'seq_len': [10, 20, 30, 40],
    'hidden_dim': [20, 30, 50],
    'epoch': [10],
    'num_layers': [1, 2, 3, 4]
}

# Models evaluation

In [None]:
def evaluate_new(spoint, sat_id):
    """Train a fresh LSTM on one satellite and score it on that satellite's test split.

    Parameters
    ----------
    spoint : dict
        Hyperparameters. The keys 'seq_len', 'hidden_dim', 'num_layers',
        'epoch', 'lr' and 'eps' are read.
    sat_id
        Index/key of the satellite in ``sat_datas_train`` and ``sat_datas_test``.

    Returns
    -------
    tuple
        ``(score, model)``: ``score`` is ``1 - mean per-sequence test loss`` as a
        Python float, ``model`` is the trained network switched to eval mode.

    Raises
    ------
    ValueError
        If the test split of ``sat_id`` yields no sequences.
    """
    seq_len = spoint['seq_len']
    hidden_dim = spoint['hidden_dim']
    num_layers = spoint['num_layers']
    ep = spoint['epoch']
    lr = spoint['lr']
    eps = spoint['eps']

    # Use the requested satellite throughout. The previous version re-read
    # sat_datas_train[n] / sat_datas_test[n] with the notebook-global `n`, so
    # every call trained and scored on the same satellite regardless of sat_id.
    train_data = sat_datas_train[sat_id]
    test_data = sat_datas_test[sat_id]

    # Cap the window length for satellites with a short history.
    max_seq_len = int(len(train_data) / 2 - 1)
    if seq_len > max_seq_len:
        seq_len = max_seq_len

    model = LSTM(hidden_dim=hidden_dim, seq_len=seq_len, num_layers=num_layers)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, eps=eps)
    criterion = smape

    # ---- training ----
    model.train()
    loss_widget = widgets.FloatProgress(min=0, max=1, step=0.01, description='Loss', value=0)  # jupyter widget
    display(loss_widget)
    x_train = train_data[features]
    y_train = train_data[targets]
    train_dataset = TrainTestSequenceDataset(x_train, y_train, seq_len=model.seq_len)
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    for epoch in tqdm(range(ep)):
        for seq_train_x, seq_train_y in train_dataloader:
            model.zero_grad()  # refresh gradients
            model.init_hidden_cell()
            predictions = model(seq_train_x)
            mean_loss = criterion(predictions, seq_train_y).mean()
            loss_widget.value = mean_loss.item()
            mean_loss.backward()  # compute gradients
            optimizer.step()  # update network parameters

    # ---- scoring ----
    model.eval()
    score_widget = widgets.FloatProgress(min=0, max=1, step=0.01, description='Score', value=0)  # jupyter widget
    display(score_widget)
    x_test = test_data[features]
    y_test = test_data[targets]
    test_dataset = TrainTestSequenceDataset(x_test, y_test, seq_len=model.seq_len)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    loss_sum = 0
    i = 0

    for seq_test_x, seq_test_y in tqdm(test_dataloader):
        with torch.no_grad():
            model.init_hidden_cell()
            predictions = model(seq_test_x)
            loss = criterion(predictions, seq_test_y).mean()
            loss_sum += loss
            i += 1
            score_widget.value = float(1 - loss_sum / i)

    if i == 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError(f'no test sequences for satellite {sat_id}')
    score = (1 - loss_sum / i).item()

    return score, model

In [None]:
# Hyperparameters applied to every satellite in the evaluation loop below.
# NOTE(review): lr=0.02, hidden_dim=100 and epoch=30 are not among the
# search_space candidates above -- presumably hand-tuned; confirm.
best_params = {'lr': 0.02, 'eps': 1e-2, 'seq_len': 20,
               'hidden_dim': 100, 'epoch': 30, 'num_layers': 2}

In [None]:
# Train and score one model per test satellite; the score and the model are
# written to models/<sat_id>/ and the scores are collected in `results`.
results = {}
for sat_id in test_sat_id:
    clear_output()
    print(f'Satellite id: {sat_id}')
    print(results)
    # Copy so that recording the score below does not add a 'score' key to the
    # shared best_params dict (the previous code aliased it).
    spoint = dict(best_params)

    score, model = evaluate_new(spoint=spoint, sat_id=sat_id)
    spoint['score'] = score
    # exist_ok avoids the separate exists() check and its race.
    os.makedirs(f'models//{sat_id}', exist_ok=True)
    results[str(sat_id)] = score
    torch.save(score, f'models//{sat_id}//score')
    # NOTE(review): this pickles the whole model object, so loading it later
    # requires the same `models` module to be importable.
    torch.save(model, f'models//{sat_id}//model.pt')

# Models factory

In [2]:
# Same values as the best_params defined in the evaluation section; presumably
# repeated so this section can run on a fresh kernel without the cells above --
# confirm and keep a single definition if not.
best_params = {'lr': 0.02, 'eps': 1e-2, 'seq_len': 20,
               'hidden_dim': 100, 'epoch': 30, 'num_layers': 2}

In [3]:
# Reload the raw table and build the training collection with process_for_train
# (no test split is produced here).
# NOTE(review): this rebinds `data` and `sat_datas_train`, which the earlier
# sections also define, so results depend on which cell ran last.
data = pd.read_csv('train.csv', parse_dates=['epoch'])
sat_datas_train = process_for_train(data)

In [4]:
# Display the processed training table stored under index/key 0.
sat_datas_train[0]

Unnamed: 0,epoch,x,y,z,Vx,Vy,Vz,x_sim,y_sim,z_sim,...,fi_sim,dro/dt_sim,dtheta/dt_sim,dfi/dt_sim,dx_sim,dy_sim,dz_sim,dro_sim,dtheta_sim,dfi_sim
0,-1.730756,-8855.823863,13117.780146,-20728.353233,-0.908303,-3.808436,-2.022083,-0.311082,0.500388,-0.999704,...,-1.291078,-0.004444,-0.937090,0.747771,-0.343891,-1.463405,-0.957448,-0.004444,-0.937090,0.747771
1,-1.727132,-10567.672384,1619.746066,-24451.813271,-0.302590,-4.272617,-0.612796,-0.378567,0.051732,-1.180247,...,-1.908559,0.434614,-1.353263,0.125439,-0.114952,-1.642185,-0.291497,0.434614,-1.353263,0.125439
2,-1.723508,-10578.684043,-10180.467460,-24238.280949,0.277435,-4.047522,0.723155,-0.379212,-0.408885,-1.170478,...,1.416777,0.757259,-1.006825,-0.538343,0.104417,-1.556416,0.340189,0.757259,-1.006825,-0.538343
3,-1.719884,-9148.251857,-20651.437460,-20720.381279,0.715600,-3.373762,1.722115,-0.323161,-0.817838,-1.000836,...,0.937656,0.927449,-0.680071,-0.604404,0.270316,-1.298146,0.813059,0.927449,-0.680071,-0.604404
4,-1.716260,-6719.092336,-28929.061629,-14938.907967,0.992507,-2.519732,2.344703,-0.227778,-1.141341,-0.721533,...,0.552240,0.972421,-0.534778,-0.546743,0.375320,-0.970269,1.108234,0.972421,-0.534778,-0.546743
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
953,1.715598,17337.596150,-3224.996803,40025.071742,-0.055615,2.650511,-0.222561,0.645359,0.471055,1.704354,...,1.487661,-0.742742,-0.766832,0.241463,-0.220788,0.958989,-0.678237,-0.742742,-0.766832,0.241463
954,1.719222,16849.590836,4217.959953,38636.167298,-0.295282,2.642711,-0.774030,0.566305,0.729596,1.468951,...,1.212410,-0.860830,-0.663148,0.348921,-0.321936,0.846475,-0.964490,-0.860830,-0.663148,0.348921
955,1.722845,15667.981809,11481.446566,35656.909015,-0.550136,2.518368,-1.356292,0.457647,0.945960,1.151565,...,0.911970,-0.949216,-0.586234,0.445073,-0.423032,0.659548,-1.247310,-0.949216,-0.586234,0.445073
956,1.726469,13754.838284,18199.705814,31013.052037,-0.816256,2.247835,-1.959266,0.320503,1.096206,0.756424,...,0.580306,-0.986231,-0.542090,0.545735,-0.515481,0.378531,-1.501082,-0.986231,-0.542090,0.545735


In [9]:
def train_new(spoint, sat_id):
    """Train a fresh LSTM on the training data of a single satellite.

    Parameters
    ----------
    spoint : dict
        Hyperparameters. The keys 'seq_len', 'hidden_dim', 'num_layers',
        'epoch', 'lr' and 'eps' are read.
    sat_id
        Index/key of the satellite in ``sat_datas_train``.

    Returns
    -------
    The trained model, switched to eval mode.
    """
    seq_len, hidden_dim, num_layers, ep, lr, eps = (
        spoint[key]
        for key in ('seq_len', 'hidden_dim', 'num_layers', 'epoch', 'lr', 'eps')
    )

    sat_frame = sat_datas_train[sat_id]
    # Cap the window length for satellites with a short history.
    seq_len = min(seq_len, int(len(sat_frame) / 2 - 1))

    model = LSTM(hidden_dim=hidden_dim, seq_len=seq_len, num_layers=num_layers)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, eps=eps)
    criterion = smape
    model.train()

    # jupyter widget showing the most recent batch loss
    loss_widget = widgets.FloatProgress(min=0, max=1, step=0.01, description='Loss', value=0)
    display(loss_widget)

    train_dataset = TrainTestSequenceDataset(
        sat_frame[features], sat_frame[targets], seq_len=model.seq_len
    )
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    for _ in tqdm(range(ep)):
        for seq_train_x, seq_train_y in train_dataloader:
            model.zero_grad()  # clear gradients from the previous step
            model.init_hidden_cell()
            batch_loss = criterion(model(seq_train_x), seq_train_y).mean()
            loss_widget.value = batch_loss
            batch_loss.backward()  # compute gradients
            optimizer.step()  # update network parameters

    model.eval()
    return model

In [10]:
# Train one model per satellite id in 0..599 and store it at models/<sat_id>/model.pt.
# NOTE(review): 600 is hard-coded -- presumably the number of satellites in
# train.csv; confirm.
for sat_id in range(600):
    clear_output()
    print(f'Satellite id: {sat_id}')
    spoint = best_params

    model = train_new(spoint=spoint, sat_id=sat_id)

    model_dir = f'models//{sat_id}'
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    torch.save(model, f'models//{sat_id}//model.pt')

Satellite id: 4


FloatProgress(value=0.0, description='Loss', max=1.0)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




KeyboardInterrupt: 