In [34]:
from lowrank.config_utils.config_parser import ConfigParser
from lowrank.layers.dense_layer import DenseLayer
from lowrank.layers.dynamic_low_rank import DynamicLowRankLayer
from lowrank.training.neural_network import FeedForward
from lowrank.training.trainer import Trainer
from lowrank.optimizers.meta_optimizer import MetaOptimizer
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import os
import sys
import argparse
import time
import torch

In [47]:
def params_in_dense(input, out):
	return input * out + out

def params_in_lowrank(input, out, rank):
	return input * rank + rank * rank + rank * out + out

def params_in_basedynamic(rank):
	layer_sizes = [(784, 128), (128, 64), (64, 10)]
	params = 0
	for x, y in layer_sizes[:-1]:
		params += params_in_lowrank(x, y, rank)
	params += params_in_dense(layer_sizes[-1][0], layer_sizes[-1][1])
	return params

total_params_dense = sum(params_in_dense(x,y) for x,y in [(784, 128), (128, 64), (64, 10)])
print(total_params_dense)

params_in_basedynamic(20)

109386


23722

In [36]:
model = FeedForward.create_from_config("config_files/basedense.toml")
trainer = Trainer.create_from_model(model)

In [37]:
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Transformation for test data (without random transformations)
test_transform = transforms.Compose([
    transforms.ToTensor()
])

# Load the MNIST dataset with the defined transformations
train_data = datasets.MNIST(root='data', train=True, transform=train_transform, download=True)
test_data = datasets.MNIST(root='data', train=False, transform=test_transform, download=True)

# Create training and test dataloaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)  # Usually, shuffling is not needed for test data


In [38]:
# train the model
trainer.train(train_loader, test_loader, patience = 30)

Epoch 1/30: 100%|██████████| 938/938 [00:06<00:00, 138.90it/s]


Epoch [1/30], Validation Accuracy: 65.30%, Validation Loss: 1.8066


Epoch 2/30: 100%|██████████| 938/938 [00:06<00:00, 153.84it/s]


Epoch [2/30], Validation Accuracy: 74.04%, Validation Loss: 1.7202


Epoch 3/30: 100%|██████████| 938/938 [00:06<00:00, 143.57it/s]


Epoch [3/30], Validation Accuracy: 75.18%, Validation Loss: 1.7086


Epoch 4/30: 100%|██████████| 938/938 [00:06<00:00, 138.95it/s]


Epoch [4/30], Validation Accuracy: 77.31%, Validation Loss: 1.6876


Epoch 5/30: 100%|██████████| 938/938 [00:07<00:00, 126.35it/s]


Epoch [5/30], Validation Accuracy: 90.91%, Validation Loss: 1.5516


Epoch 6/30: 100%|██████████| 938/938 [00:06<00:00, 137.07it/s]


Epoch [6/30], Validation Accuracy: 92.75%, Validation Loss: 1.5336


Epoch 7/30: 100%|██████████| 938/938 [00:06<00:00, 144.75it/s]


Epoch [7/30], Validation Accuracy: 93.14%, Validation Loss: 1.5294


Epoch 8/30: 100%|██████████| 938/938 [00:07<00:00, 122.13it/s]


Epoch [8/30], Validation Accuracy: 92.73%, Validation Loss: 1.5335


Epoch 9/30: 100%|██████████| 938/938 [00:07<00:00, 121.04it/s]


Epoch [9/30], Validation Accuracy: 93.67%, Validation Loss: 1.5243


Epoch 10/30: 100%|██████████| 938/938 [00:06<00:00, 144.56it/s]


Epoch [10/30], Validation Accuracy: 92.51%, Validation Loss: 1.5359


Epoch 11/30: 100%|██████████| 938/938 [00:06<00:00, 146.44it/s]


Epoch [11/30], Validation Accuracy: 92.89%, Validation Loss: 1.5322


Epoch 12/30: 100%|██████████| 938/938 [00:06<00:00, 147.10it/s]


Epoch [12/30], Validation Accuracy: 93.26%, Validation Loss: 1.5284


Epoch 13/30: 100%|██████████| 938/938 [00:06<00:00, 138.81it/s]


Epoch [13/30], Validation Accuracy: 93.71%, Validation Loss: 1.5239


Epoch 14/30: 100%|██████████| 938/938 [00:07<00:00, 132.08it/s]


Epoch [14/30], Validation Accuracy: 92.62%, Validation Loss: 1.5346


Epoch 15/30: 100%|██████████| 938/938 [00:06<00:00, 140.01it/s]


Epoch [15/30], Validation Accuracy: 92.20%, Validation Loss: 1.5389


Epoch 16/30: 100%|██████████| 938/938 [00:06<00:00, 136.53it/s]


Epoch [16/30], Validation Accuracy: 93.65%, Validation Loss: 1.5247


Epoch 17/30: 100%|██████████| 938/938 [00:06<00:00, 145.75it/s]


Epoch [17/30], Validation Accuracy: 94.07%, Validation Loss: 1.5201


Epoch 18/30: 100%|██████████| 938/938 [00:06<00:00, 146.31it/s]


Epoch [18/30], Validation Accuracy: 92.89%, Validation Loss: 1.5317


Epoch 19/30: 100%|██████████| 938/938 [00:07<00:00, 122.55it/s]


Epoch [19/30], Validation Accuracy: 94.42%, Validation Loss: 1.5167


Epoch 20/30: 100%|██████████| 938/938 [00:07<00:00, 122.32it/s]


Epoch [20/30], Validation Accuracy: 92.25%, Validation Loss: 1.5384


Epoch 21/30: 100%|██████████| 938/938 [00:07<00:00, 121.17it/s]


Epoch [21/30], Validation Accuracy: 94.25%, Validation Loss: 1.5182


Epoch 22/30: 100%|██████████| 938/938 [00:07<00:00, 119.58it/s]


Epoch [22/30], Validation Accuracy: 92.08%, Validation Loss: 1.5400


Epoch 23/30: 100%|██████████| 938/938 [00:07<00:00, 124.49it/s]


Epoch [23/30], Validation Accuracy: 94.02%, Validation Loss: 1.5206


Epoch 24/30: 100%|██████████| 938/938 [00:07<00:00, 121.94it/s]


Epoch [24/30], Validation Accuracy: 92.76%, Validation Loss: 1.5330


Epoch 25/30: 100%|██████████| 938/938 [00:06<00:00, 138.79it/s]


Epoch [25/30], Validation Accuracy: 94.89%, Validation Loss: 1.5120


Epoch 26/30: 100%|██████████| 938/938 [00:06<00:00, 141.09it/s]


Epoch [26/30], Validation Accuracy: 90.92%, Validation Loss: 1.5513


Epoch 27/30: 100%|██████████| 938/938 [00:06<00:00, 135.05it/s]


Epoch [27/30], Validation Accuracy: 93.76%, Validation Loss: 1.5231


Epoch 28/30: 100%|██████████| 938/938 [00:06<00:00, 138.46it/s]


Epoch [28/30], Validation Accuracy: 91.82%, Validation Loss: 1.5425


Epoch 29/30: 100%|██████████| 938/938 [00:06<00:00, 143.75it/s]


Epoch [29/30], Validation Accuracy: 90.97%, Validation Loss: 1.5516


Epoch 30/30: 100%|██████████| 938/938 [00:06<00:00, 139.09it/s]


Epoch [30/30], Validation Accuracy: 91.74%, Validation Loss: 1.5434


(FeedForward(
   (layers): ModuleList(
     (0): Flatten(start_dim=1, end_dim=-1)
     (1-2): 2 x DenseLayer(
       (activation): ReLU()
     )
     (3): DenseLayer(
       (activation): Identity()
     )
     (4): Softmax(dim=1)
   )
 ),
 [{'epoch': 1,
   'train_loss': 1.850622563346871,
   'val_accuracy': 0.653,
   'val_loss': 1.8065910779746475},
  {'epoch': 2,
   'train_loss': 1.7420614158420928,
   'val_accuracy': 0.7404,
   'val_loss': 1.720169954239183},
  {'epoch': 3,
   'train_loss': 1.7100022630905038,
   'val_accuracy': 0.7518,
   'val_loss': 1.7086406885438663},
  {'epoch': 4,
   'train_loss': 1.7046074546984773,
   'val_accuracy': 0.7731,
   'val_loss': 1.6876134583904485},
  {'epoch': 5,
   'train_loss': 1.5881324701471877,
   'val_accuracy': 0.9091,
   'val_loss': 1.5516288348823597},
  {'epoch': 6,
   'train_loss': 1.5423779121594134,
   'val_accuracy': 0.9275,
   'val_loss': 1.5336408911237291},
  {'epoch': 7,
   'train_loss': 1.5411658972056943,
   'val_accuracy': 0.

In [44]:
model2 = FeedForward.create_from_config("config_files/basedynamiclowrank.toml")
trainer2 = Trainer.create_from_model(model2)

In [45]:
trainer2.train(train_loader, test_loader, patience = 30)

Epoch 1/30: 100%|██████████| 938/938 [00:07<00:00, 132.76it/s]


Epoch [1/30], Validation Accuracy: 61.94%, Validation Loss: 1.2728


Epoch 2/30: 100%|██████████| 938/938 [00:06<00:00, 144.76it/s]


Epoch [2/30], Validation Accuracy: 88.79%, Validation Loss: 0.3708


Epoch 3/30: 100%|██████████| 938/938 [00:06<00:00, 144.00it/s]


Epoch [3/30], Validation Accuracy: 87.52%, Validation Loss: 0.4446


Epoch 4/30: 100%|██████████| 938/938 [00:06<00:00, 144.85it/s]


Epoch [4/30], Validation Accuracy: 85.55%, Validation Loss: 0.4923


Epoch 5/30: 100%|██████████| 938/938 [00:06<00:00, 146.68it/s]


Epoch [5/30], Validation Accuracy: 62.61%, Validation Loss: 1.3850


Epoch 6/30: 100%|██████████| 938/938 [00:06<00:00, 147.04it/s]


Epoch [6/30], Validation Accuracy: 89.84%, Validation Loss: 0.3280


Epoch 7/30: 100%|██████████| 938/938 [00:06<00:00, 144.64it/s]


Epoch [7/30], Validation Accuracy: 92.62%, Validation Loss: 0.2384


Epoch 8/30: 100%|██████████| 938/938 [00:07<00:00, 133.99it/s]


Epoch [8/30], Validation Accuracy: 93.16%, Validation Loss: 0.2276


Epoch 9/30: 100%|██████████| 938/938 [00:06<00:00, 137.11it/s]


Epoch [9/30], Validation Accuracy: 93.06%, Validation Loss: 0.2285


Epoch 10/30: 100%|██████████| 938/938 [00:06<00:00, 134.72it/s]


Epoch [10/30], Validation Accuracy: 88.04%, Validation Loss: 0.4279


Epoch 11/30: 100%|██████████| 938/938 [00:06<00:00, 138.89it/s]


Epoch [11/30], Validation Accuracy: 90.00%, Validation Loss: 0.3328


Epoch 12/30: 100%|██████████| 938/938 [00:06<00:00, 139.10it/s]


Epoch [12/30], Validation Accuracy: 90.42%, Validation Loss: 0.3181


Epoch 13/30: 100%|██████████| 938/938 [00:06<00:00, 138.86it/s]


Epoch [13/30], Validation Accuracy: 90.28%, Validation Loss: 0.3243


Epoch 14/30: 100%|██████████| 938/938 [00:06<00:00, 140.60it/s]


Epoch [14/30], Validation Accuracy: 90.13%, Validation Loss: 0.3397


Epoch 15/30: 100%|██████████| 938/938 [00:06<00:00, 136.22it/s]


Epoch [15/30], Validation Accuracy: 84.17%, Validation Loss: 0.5804


Epoch 16/30: 100%|██████████| 938/938 [00:07<00:00, 132.70it/s]


Epoch [16/30], Validation Accuracy: 84.03%, Validation Loss: 0.5320


Epoch 17/30: 100%|██████████| 938/938 [00:06<00:00, 135.13it/s]


Epoch [17/30], Validation Accuracy: 88.94%, Validation Loss: 0.3676


Epoch 18/30: 100%|██████████| 938/938 [00:07<00:00, 132.51it/s]


Epoch [18/30], Validation Accuracy: 79.74%, Validation Loss: 0.6710


Epoch 19/30: 100%|██████████| 938/938 [00:07<00:00, 132.63it/s]


Epoch [19/30], Validation Accuracy: 88.16%, Validation Loss: 0.4113


Epoch 20/30: 100%|██████████| 938/938 [00:07<00:00, 133.75it/s]


Epoch [20/30], Validation Accuracy: 40.21%, Validation Loss: 1.7376


Epoch 21/30: 100%|██████████| 938/938 [00:06<00:00, 136.67it/s]


Epoch [21/30], Validation Accuracy: 90.67%, Validation Loss: 0.3349


Epoch 22/30: 100%|██████████| 938/938 [00:06<00:00, 134.80it/s]


Epoch [22/30], Validation Accuracy: 89.15%, Validation Loss: 0.3652


Epoch 23/30: 100%|██████████| 938/938 [00:06<00:00, 134.52it/s]


Epoch [23/30], Validation Accuracy: 84.87%, Validation Loss: 0.5217


Epoch 24/30: 100%|██████████| 938/938 [00:06<00:00, 134.81it/s]


Epoch [24/30], Validation Accuracy: 89.42%, Validation Loss: 0.3465


Epoch 25/30: 100%|██████████| 938/938 [00:07<00:00, 130.58it/s]


Epoch [25/30], Validation Accuracy: 88.63%, Validation Loss: 0.3656


Epoch 26/30: 100%|██████████| 938/938 [00:06<00:00, 137.57it/s]


Epoch [26/30], Validation Accuracy: 88.66%, Validation Loss: 0.4325


Epoch 27/30: 100%|██████████| 938/938 [00:06<00:00, 135.62it/s]


Epoch [27/30], Validation Accuracy: 90.10%, Validation Loss: 0.3346


Epoch 28/30: 100%|██████████| 938/938 [00:07<00:00, 125.26it/s]


Epoch [28/30], Validation Accuracy: 82.09%, Validation Loss: 0.5592


Epoch 29/30: 100%|██████████| 938/938 [00:06<00:00, 138.99it/s]


Epoch [29/30], Validation Accuracy: 88.46%, Validation Loss: 0.3828


Epoch 30/30: 100%|██████████| 938/938 [00:06<00:00, 135.53it/s]


Epoch [30/30], Validation Accuracy: 87.02%, Validation Loss: 0.4289


(FeedForward(
   (layers): ModuleList(
     (0): Flatten(start_dim=1, end_dim=-1)
     (1-2): 2 x DynamicLowRankLayer(
       (activation): ReLU()
     )
     (3): DenseLayer(
       (activation): Identity()
     )
   )
 ),
 [{'epoch': 1,
   'train_loss': 1.0565884807375447,
   'val_accuracy': 0.6194,
   'val_loss': 1.2727796708702281},
  {'epoch': 2,
   'train_loss': 1.0402665419190296,
   'val_accuracy': 0.8879,
   'val_loss': 0.3708132586922425},
  {'epoch': 3,
   'train_loss': 1.382035160465027,
   'val_accuracy': 0.8752,
   'val_loss': 0.4446459924149665},
  {'epoch': 4,
   'train_loss': 1.327798548688703,
   'val_accuracy': 0.8555,
   'val_loss': 0.49230400548808895},
  {'epoch': 5,
   'train_loss': 0.4080968014578194,
   'val_accuracy': 0.6261,
   'val_loss': 1.3850463064992504},
  {'epoch': 6,
   'train_loss': 6.178895355724513,
   'val_accuracy': 0.8984,
   'val_loss': 0.32795785312914544},
  {'epoch': 7,
   'train_loss': 0.322584460547039,
   'val_accuracy': 0.9262,
   'val_l