In [694]:
# Environment provenance: interpreter path, current user, host name and working directory.
# NOTE(review): the saved output of this cell exposes a user name, a host name and
# absolute local paths — clear it before sharing the notebook.
!which python
!whoami 
!hostname
!pwd

/home/farshed.abdukhakimov/miniconda3/envs/main/bin/python
farshed.abdukhakimov
srv-01
/home/farshed.abdukhakimov/projects/twin-polyak/experiments


In [702]:
%load_ext autoreload
%autoreload 2

import os
import datetime
import time
from collections import defaultdict
import pickle
import csv

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import numpy as np

import matplotlib.pyplot as plt

import utils
from utils import moving_average
from solve_binary_libsvm import solve_binary_libsvm

import sklearn
import sklearn.datasets
from sklearn.model_selection import train_test_split

import scipy
import svmlight_loader

import lightning as L
import torchmetrics

from pt_methods import TwinPolyakMA, Momo
import sps
from utils import SimpleDataset

from custom_logger import DBLogger
from base_module import BaseTrainingModule

from dotenv import load_dotenv
load_dotenv()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


True

In [699]:
# CPU device handle; the bare expression below echoes it in the cell output.
# NOTE(review): `device` is not referenced in the other visible cells — the Trainer
# is given accelerator='cpu' directly. Confirm it is still needed.
device = torch.device('cpu')
device

device(type='cpu')

In [700]:
# Load the raw abalone file once for a quick look at its shape and targets.
data_dir = os.getenv("LIBSVM_DIR")
if data_dir is None:
    # os.getenv returns None for a missing variable; fail early with a clear
    # message instead of trying to open the path 'None/abalone_scale'.
    raise RuntimeError("LIBSVM_DIR is not set; export it or add it to the .env file")

data, target = sklearn.datasets.load_svmlight_file(os.path.join(data_dir, 'abalone_scale'))


# Echo the sparse feature matrix summary and the target array.
data, target

(<4177x8 sparse matrix of type '<class 'numpy.float64'>'
 	with 32080 stored elements in Compressed Sparse Row format>,
 array([15.,  7.,  9., ...,  9., 10., 12.]))

In [701]:
class AbaloneDataModule(L.LightningDataModule):
    """Lightning data module for the LIBSVM ``abalone_scale`` regression file.

    The file is read once in ``__init__``; every row is scaled to unit L2 norm and
    the rows are split 80/20 into train and validation parts with a fixed
    ``random_state``. Batches are collated into a sparse COO feature tensor and a
    float32 target vector. There is no separate test split: ``test_dataloader``
    serves the validation split.

    Args:
        data_dir: Directory containing the ``abalone_scale`` file. If ``None``,
            the ``LIBSVM_DIR`` environment variable is re-read at construction time.
        batch_size: Number of samples per batch for every dataloader.
        num_workers: Number of worker processes for every dataloader.

    Raises:
        ValueError: If no data directory is given and ``LIBSVM_DIR`` is unset.
    """

    def __init__(self, data_dir: str = os.getenv("LIBSVM_DIR"), batch_size: int = 32, num_workers: int = 2):
        super().__init__()

        # Imported explicitly: a bare `import sklearn` does not guarantee that the
        # `preprocessing` submodule has been loaded.
        from sklearn.preprocessing import normalize

        # The signature default is captured when the class is defined, so re-read
        # the variable here in case the environment was populated afterwards.
        if data_dir is None:
            data_dir = os.getenv("LIBSVM_DIR")
        if data_dir is None:
            raise ValueError("No data directory given and LIBSVM_DIR is not set")

        data, target = sklearn.datasets.load_svmlight_file(os.path.join(data_dir, 'abalone_scale'))
        data = normalize(data, norm='l2', axis=1)
        self.train_data, self.val_data, self.train_target, self.val_target = train_test_split(
            data, target, test_size=0.2, random_state=0
        )

        self.batch_size: int = batch_size
        self.num_workers: int = num_workers
        self.num_features: int = data.shape[1]

    def setup(self, stage: str):
        """Create the datasets required for ``stage``.

        The training dataset is only built for ``'fit'``. The validation dataset is
        built for every stage, including ``'validate'``, so that
        ``val_dataloader`` and ``test_dataloader`` always have data to serve.
        """
        if stage == 'fit':
            self.train_dataset = SimpleDataset(self.train_data, self.train_target)
        if stage in ('fit', 'validate', 'test', 'predict'):
            self.val_dataset = SimpleDataset(self.val_data, self.val_target)

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an ordered loader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return an ordered loader over the validation split (no separate test split exists)."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def _make_loader(self, dataset, shuffle: bool):
        """Build a DataLoader with the module's batch size, worker count and sparse collate."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=self._sparse_collate,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def _sparse_coo_to_tensor(self, coo):
        """Convert a SciPy COO matrix into a float32 ``torch`` sparse COO tensor."""
        indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
        values = torch.as_tensor(coo.data, dtype=torch.float32)
        return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

    def _sparse_collate(self, batch):
        """Stack per-sample ``(features, target)`` pairs into one batch.

        Returns a sparse COO feature tensor and a float32 target tensor.
        Assumes each sample's features are a SciPy sparse row, as produced by
        slicing the loaded matrix — TODO confirm against ``SimpleDataset``.
        """
        xs, ys = zip(*batch)

        stacked = scipy.sparse.vstack(xs).tocoo()
        features = self._sparse_coo_to_tensor(stacked)

        return features, torch.tensor(ys, dtype=torch.float32)



In [705]:
# Registry mapping an optimizer's name to its class.
# NOTE(review): the run configuration in this notebook asks for 'STP', which has
# no entry here — presumably that name is resolved elsewhere; confirm.
optimizers_dict = dict(
    Adam=torch.optim.Adam,
    Momo=Momo,
    SPS=sps.Sps,
    Adagrad=torch.optim.Adagrad,
)

class RegressionModel(nn.Module):
    """Single affine layer mapping ``input_dim`` features to one value per sample."""

    def __init__(self, input_dim: int):
        super().__init__()
        # One output unit: a single real-valued prediction per input row.
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Apply the linear layer and drop the trailing singleton dimension."""
        prediction = self.linear(x)
        return prediction.squeeze(dim=-1)

class AbaloneRegressor(BaseTrainingModule):
    """Training module that fits a linear regressor with MSE loss on abalone_scale.

    Supplies the model, loss, validation metric and batch-unpacking hooks defined
    below; how they are driven is up to ``BaseTrainingModule``, which is not
    visible in this notebook.
    """
    
    def __init__(self, input_dim: int, config: dict) -> None:
        """Store the input size, record hyperparameters, then initialise the base module.

        Args:
            input_dim: Number of input features for the linear model.
            config: Run configuration; stored under the 'config' hyperparameter
                key and forwarded to ``BaseTrainingModule.__init__``.
        """
        
        # Set before super().__init__() because build_model() reads it;
        # presumably the base initialiser calls build_model() — confirm.
        self.input_dim = input_dim

        # NOTE(review): save_hyperparameters() runs before super().__init__();
        # confirm the base initialiser does not reset any state this call sets up.
        self.save_hyperparameters(
            {
                'dataset': 'abalone_scale',
                'task': 'regression',
                'model': 'linear',
                'config': config,
            }
        )
         
        super().__init__(config)

    def build_model(self):
        """Return a new linear RegressionModel sized to ``self.input_dim``."""
        return RegressionModel(self.input_dim)
    
    def define_loss_fn(self):
        """Return the loss callable: mean squared error (``F.mse_loss``)."""
        return F.mse_loss
    
    def define_val_acc_metric(self) -> None:
        """Return None: no validation accuracy metric is defined for this module."""
        return None
    
    def unpack_batch(self, batch):
        """Split a batch into ``(features, targets)``, converting the features to dense."""
        x, y = batch
        # The data module's collate function emits a sparse COO tensor for x.
        return x.to_dense(), y
        

In [706]:
from lightning.pytorch import Trainer, seed_everything, loggers

seed = 0

# Settings for this run; the same dict is handed to the training module.
config = dict(
    seed=seed,
    max_epochs=50,
    batch_size=128,
    optimizer='STP',
    optimizer_hparams=dict(beta=0.0),
)

# Data: build the module and prepare the train/validation datasets up front so
# the number of features and batches is known before the model and trainer exist.
data_module = AbaloneDataModule(batch_size=config['batch_size'])
data_module.setup('fit')

seed_everything(seed, workers=True)

model = AbaloneRegressor(input_dim=data_module.num_features, config=config)

# Logging: a DB callback plus a CSV logger versioned by the launch timestamp.
db_logger_callback = DBLogger()
run_version = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
csv_logger = loggers.CSVLogger(
    save_dir=f"logs/{model.hparams['dataset']}",
    version=run_version,
)

# Log at most every 50 steps, or once per epoch when an epoch has fewer batches.
steps_per_epoch = len(data_module.train_dataloader())
trainer = L.Trainer(
    accelerator='cpu',
    max_epochs=config['max_epochs'],
    logger=[csv_logger],
    callbacks=[db_logger_callback],
    log_every_n_steps=min(steps_per_epoch, 50),
)

trainer.fit(model=model, datamodule=data_module)

Seed set to 0
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: False
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
HPU available: False, using: 0 HPUs

  | Name    | Type            | Params | Mode 
----------------------------------------------------
0 | model_x | RegressionModel | 9      | train
1 | model_y | RegressionModel | 9      | train
----------------------------------------------------
18        Trainable params
0         Non-trainable params
18        Total params
0.000     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval 

Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=50` reached.
`Trainer.fit` stopped: `max_epochs=50` reached.


12