# Sunspots Data Prediction

## Colab Settings

In [None]:
# Mount Google Drive at /content/drive so that files stored there can be
# reached from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

## Import Libraries

In [None]:
# Install PyTorch Lightning into the runtime.
# NOTE(review): the version is not pinned, while this notebook imports
# `pytorch_lightning.metrics` and passes `gpus=`, `progress_bar_refresh_rate=`
# and `weights_save_path=` to `pl.Trainer`. These look like legacy APIs that
# recent releases may no longer provide -- confirm and pin a compatible version.
!pip install pytorch-lightning

In [None]:
# Switch the working directory to the project folder on the mounted Drive.
# The path is relative -- presumably the current directory is /content (TODO
# confirm). The `GetData` and `BaseModel` modules imported later are presumably
# located in this folder.
%cd drive/MyDrive/Research3/train

In [None]:
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import linalg as LA
from torch.nn import functional as F

import pytorch_lightning as pl
from pytorch_lightning import loggers, callbacks
from pytorch_lightning.metrics import functional as FM

import GetData
from BaseModel import BaseModel

## Define Model

In [None]:
class Layer(torch.nn.Module):
    """Tree-structured GRU layer whose topology is grown one edge at a time.

    Every sequence position is a node. ``parent[i]`` is the index of the node
    that consumes node ``i``'s hidden state, or ``-1`` while node ``i`` is not
    attached to anything. The hidden state fed into a node is the mean of its
    children's hidden states (zeros for a node without children).

    ``forward`` accumulates, in ``score``, the pairwise cosine similarity of
    the per-position hidden states summed over the batch. ``add_edge`` uses
    that matrix to sample which unattached node to hook under which attached
    node, then resets the matrix.
    """

    def __init__(self, n_seq, n_input, n_hidden):
        """Create an edgeless layer.

        Args:
            n_seq: number of sequence positions (tree nodes).
            n_input: feature size of each input step.
            n_hidden: hidden-state size of the shared ``GRUCell``.
        """
        super().__init__()
        self.n_seq = n_seq
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.rnncell = torch.nn.GRUCell(n_input, n_hidden)
        # Frozen integer Parameter: -1 marks "no parent yet".
        self.parent = nn.parameter.Parameter(-torch.ones(n_seq, dtype=torch.long), requires_grad=False)
        self.childs = self._construct_childs()
        self.score = torch.zeros((n_seq, n_seq))

    def forward(self, X):
        """Run the cell over the tree from the leaves upwards.

        Args:
            X: tensor of shape ``(n_seq, batch, n_input)``.

        Returns:
            ``(hx, hx[-1])`` where ``hx`` has shape ``(n_seq, batch, n_hidden)``
            and holds the hidden state of every node.

        Raises:
            ValueError: if ``X`` does not have the expected shape.

        Side effect: adds this batch's similarity matrix to ``self.score`` on
        every call.
        """
        if X.dim() != 3 or X.size(0) != self.n_seq or X.size(2) != self.n_input:
            raise ValueError(
                f"expected input of shape ({self.n_seq}, batch, {self.n_input}), "
                f"got {tuple(X.shape)}"
            )
        batch = X.size(1)
        hx = torch.empty((self.n_seq, batch, self.n_hidden), dtype=X.dtype, device=X.device)

        # Work with plain Python ints so the queue never holds 0-dim tensors.
        parent = self.parent.tolist()
        pending = [len(children) for children in self.childs]

        # Start from the nodes without children; a node becomes ready once all
        # of its children have been evaluated.
        frontier = [node for node, count in enumerate(pending) if count == 0]
        while frontier:
            # Evaluate every ready node in a single GRUCell call.
            agg = torch.cat([self._aggregate(hx[self.childs[node]]) for node in frontier], 0)
            inputs = torch.cat([X[node] for node in frontier], 0)
            out = self.rnncell(inputs, agg).reshape(len(frontier), batch, self.n_hidden)
            for slot, node in enumerate(frontier):
                hx[node] = out[slot]

            ready = []
            for node in frontier:
                par = parent[node]
                if par == -1:
                    continue
                pending[par] -= 1
                if pending[par] == 0:
                    ready.append(par)
            frontier = ready

        # Keep the accumulator on the same device/dtype as the activations.
        self.score = self.score.to(X)
        with torch.no_grad():
            self.score += self._dot_product(hx)
        return hx, hx[-1]

    def add_edge(self):
        """Attach one currently unattached node to the tree.

        The edge is sampled by ``_get_index``; afterwards the child lists are
        rebuilt and the similarity accumulator is cleared. Does nothing when
        every node is already connected.
        """
        if bool(self._connected_end().all()):
            return
        old, new = self._get_index()
        # Plain ints avoid any device mismatch between ``score`` and ``parent``.
        self.parent[int(new)] = int(old)
        self.childs = self._construct_childs()
        self.score = torch.zeros_like(self.score)

    def _construct_childs(self):
        """Return, for every node, the list of node indices whose parent it is."""
        childs = [[] for _ in range(self.n_seq)]
        for idx, par in enumerate(self.parent.tolist()):
            if par >= 0:
                childs[par].append(idx)
        return childs

    def _connected_end(self):
        """Return a bool vector marking nodes that have a parent, plus the last node."""
        connected = (self.parent != -1)
        connected[-1] = True
        return connected

    def _aggregate(self, h):
        """Average child states over dim 0; zeros when there are no children."""
        if h.size(0) == 0:
            return torch.zeros(h.shape[1:], dtype=h.dtype, device=h.device)
        return h.mean(dim=0)

    def _dot_product(self, hx):
        """Return the ``(n_seq, n_seq)`` sum over the batch of cosine similarities.

        Each hidden vector is scaled to unit length along the hidden dimension.
        The norm is clamped away from zero so an all-zero hidden vector
        contributes 0 instead of turning the accumulated score into NaN.
        """
        norms = LA.norm(hx, dim=2, keepdim=True)
        unit = hx / norms.clamp_min(torch.finfo(hx.dtype).eps)
        return torch.tensordot(unit, unit, dims=([1, 2], [1, 2]))

    def _get_index(self):
        """Sample an ``(attached, unattached)`` node pair from the score matrix.

        Candidates are entries ``(row, col)`` with ``row > col`` where ``row``
        is connected (see ``_connected_end``) and ``col`` is not. Their scores,
        divided by ``sqrt(score[0][0])``, go through a softmax and one
        candidate is drawn by inverting the cumulative distribution.

        Returns:
            ``(row, col)`` as 0-dim integer tensors.

        Raises:
            IndexError: if there is no candidate edge left.
        """
        connected = self._connected_end().to(self.score.device)
        mask = torch.ones_like(self.score, dtype=torch.bool).tril(diagonal=-1)
        mask[~connected, :] = False
        mask[:, connected] = False

        # Row-major list of candidate coordinates, same order as score[mask].
        candidates = mask.nonzero()
        if candidates.size(0) == 0:
            raise IndexError("no attachable edge: every node is already connected")

        # NOTE: if nothing has been accumulated yet, score[0][0] is 0 and the
        # division gives NaN; the comparison below is then all-False and the
        # first candidate is chosen.
        logits = self.score[mask] / torch.sqrt(self.score[0][0])
        cdf = torch.cumsum(torch.nn.functional.softmax(logits, dim=0), dim=0)
        cursor = torch.rand(1).to(cdf)
        pick = int((cdf < cursor).sum())
        # Rounding can leave cdf[-1] slightly below the draw; stay in range.
        pick = min(pick, candidates.size(0) - 1)

        row, col = candidates[pick]
        return row, col

In [None]:
class GRU(torch.nn.Module):
    """Sequential GRU baseline that applies one shared ``GRUCell`` step by step."""

    def __init__(self, n_seq, n_input, n_hidden):
        """Store the sizes and build the cell.

        Args:
            n_seq: number of time steps expected by ``forward``.
            n_input: feature size of each input step.
            n_hidden: hidden-state size.
        """
        super().__init__()
        self.n_seq, self.n_input, self.n_hidden = n_seq, n_input, n_hidden
        self.rnncell = torch.nn.GRUCell(n_input, n_hidden)

    def forward(self, X):
        """Unroll the cell over ``X`` of shape ``(n_seq, batch, n_input)``.

        Returns:
            ``(history, state)``: all hidden states stacked as
            ``(n_seq, batch, n_hidden)`` and the final hidden state.

        Raises:
            ValueError: if ``X`` does not have the expected shape.
        """
        shape_ok = (
            X.dim() == 3
            and X.size(0) == self.n_seq
            and X.size(2) == self.n_input
        )
        if not shape_ok:
            raise ValueError()

        batch = X.size(1)
        like_input = {'dtype': X.dtype, 'device': X.device}
        history = torch.empty((self.n_seq, batch, self.n_hidden), **like_input)
        # The recurrence starts from an all-zero hidden state.
        state = torch.zeros((batch, self.n_hidden), **like_input)
        for step in range(self.n_seq):
            state = self.rnncell(X[step], state)
            history[step] = state
        return history, state

In [None]:
class Model(BaseModel):
    """Recurrent encoder followed by a linear head, on top of ``BaseModel``."""

    def __init__(self, rnn, n_input, n_hidden, n_x, n_y, batch_size, *args, **kwargs):
        """Build the recurrent layer and the output head.

        Args:
            rnn: recurrent layer *class*; instantiated as
                ``rnn(n_x, n_input, n_hidden)``.
            n_input: feature size of each input step.
            n_hidden: hidden-state size of the recurrent layer.
            n_x: sequence length handed to the recurrent layer.
            n_y: output size of the linear head.
            batch_size: accepted but not used by this class.
            *args, **kwargs: forwarded unchanged to ``BaseModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        # True for the plain GRU baseline, which has no edges to grow.
        self.subNN = issubclass(rnn, GRU)
        self.rnn = rnn(n_x, n_input, n_hidden)
        self.fc = nn.Linear(n_hidden, n_y)

    def forward(self, X):
        """Swap the first two dims of ``X``, encode it, and map the last state to ``n_y`` outputs."""
        states, _ = self.rnn(X.transpose(0, 1))
        return self.fc(states[-1])

    def training_epoch_end(self, outputs):
        """Grow the tree by one edge per epoch until every node is connected."""
        if self.subNN:
            return
        if self.rnn._connected_end().all():
            return
        self.rnn.add_edge()

    def test_epoch_end(self, outputs):
        """Scatter-plot the two series collected from the test steps.

        ``outputs`` is unpacked as pairs; presumably (prediction, target) as
        returned by ``BaseModel``'s test step -- TODO confirm.
        """
        first, second = zip(*outputs)
        first = torch.cat(first, 0).cpu().numpy()
        second = torch.cat(second, 0).cpu().numpy()
        plt.scatter(second, first, s=1)
        plt.show()

## Train And Predict

In [None]:
# Experiment-wide settings shared by the cells below.
ex_name = 'sunspots'
batch_size = 128
n_x = 11 * 12  # sequence length given to the data module and the model (optionally * 10)
n_y = 1  # number of model outputs (alternatively 10)

# (metric function, display name) pairs forwarded to the model.
metrics_fn = [(FM.mean_absolute_error, 'mae')]

# Log every run both to TensorBoard and to CSV under the experiment name.
tb_logger = loggers.TensorBoardLogger('TB_logs', name=ex_name)
csv_logger = loggers.CSVLogger('CSV_logs', name=ex_name)
loggers_arr = [tb_logger, csv_logger]

In [None]:
# Data source: windows of n_x inputs and n_y targets, served unshuffled.
data_module = GetData.SunspotData(
    n_x,
    n_y,
    batch_size=batch_size,
    shuffle=False,
)

# Tree-growing Layer with 1 input feature and 16 hidden units; F.mse_loss and
# the metric list are passed through to BaseModel.
model = Model(
    Layer,  # rnn
    1,      # n_input
    16,     # n_hidden
    n_x,
    n_y,
    batch_size,
    F.mse_loss,
    metrics_fn=metrics_fn,
)

# Single-GPU trainer, 100 epochs, logging to both configured loggers.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=100,
    progress_bar_refresh_rate=30,
    logger=loggers_arr,
    weights_save_path='lightning_logs',
)

In [None]:
# Train the model on the sunspot data module. When the model wraps a `Layer`,
# `Model.training_epoch_end` adds one tree edge after each training epoch until
# every node is connected.
trainer.fit(model, datamodule=data_module)

In [None]:
# Run the test loop on the data module; `Model.test_epoch_end` then shows a
# scatter plot of the collected outputs.
# NOTE(review): no model is passed here -- confirm that the trainer evaluates
# the model fitted above with the intended weights.
trainer.test(datamodule=data_module)

In [None]:
# For the tree-growing layer, show the learned parent index of every node
# (-1 means the node has no parent).
if not model.subNN:
    print(model.rnn.parent)