In [3]:
import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
%matplotlib inline

from pclib.nn.models import FCClassifierUs
from pclib.optim.train import train
from pclib.optim.eval import track_vfe, accuracy
from pclib.utils.customdataset import PreloadedDataset

In [4]:
# Use the first GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [5]:
with open('../Datasets/mini_shakespeare.txt', 'r', encoding='utf8') as f:
    text = f.read()

# Character vocabulary. Ids start at 1 so that 0 stays free as a padding id.
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars, start=1)}
itos = {i: c for c, i in stoi.items()}
# Number of distinct ids including the reserved padding id 0.
vocab_len = len(chars) + 1


def encode(x):
    """Map a string to a list of token ids in ``1..len(chars)``."""
    return [stoi[ch] for ch in x]


def decode(x):
    """Map token ids back to a string, skipping the padding id 0.

    Accepts any iterable of integer ids, including a 1-D tensor. Each element
    is converted with ``int()`` because 0-d tensors hash by identity and would
    never match the integer keys of ``itos``.
    """
    return ''.join(itos[int(i)] for i in x if int(i) != 0)

In [6]:
# Batch size passed to `train` (64 was never used: it was overwritten with 256
# before any read, so the config now states the value actually trained with).
BATCH_SIZE = 256
# Context length: number of preceding characters fed to the model per sample.
BLOCK_SIZE = 32

In [7]:
class CustomDataset(Dataset):
    """Randomly sampled next-token windows over an encoded corpus.

    Each item is an ``(x, y)`` pair:

    * ``y`` is the token at a random position ``hi_idx`` in ``[1, len(data))``.
    * ``x`` is the (up to) ``block_size`` tokens immediately before ``y``.

    Contexts shorter than ``block_size`` occur near the start of ``data``.
    They are left-padded with 0, the id the vocabulary leaves unused, so every
    ``x`` has the same length and the default DataLoader collation can stack
    them.

    The index passed to ``__getitem__`` is ignored. Positions are drawn with
    ``torch.randint``, so one "epoch" is ``length`` fresh random samples and is
    reproducible via ``torch.manual_seed``.

    Args:
        data: 1-D sequence of token ids, presumably a ``torch.Tensor``
            (anything ``torch.as_tensor`` accepts works).
        block_size: context length. When ``None``, falls back to the
            notebook-level ``BLOCK_SIZE``.
        length: value reported by ``__len__`` (samples per epoch).

    Raises:
        ValueError: if ``data`` has fewer than 2 tokens (no ``(x, y)`` pair
            can be formed).
    """

    def __init__(self, data, block_size=None, length=500):
        if len(data) < 2:
            raise ValueError('data needs at least 2 tokens to form an (x, y) pair')
        self.data = data
        self.block_size = BLOCK_SIZE if block_size is None else block_size
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # `idx` is required by the map-style Dataset protocol but unused:
        # samples are drawn at a random position instead.
        hi_idx = torch.randint(1, len(self.data), (1,)).item()
        lo_idx = max(0, hi_idx - self.block_size)
        x = torch.as_tensor(self.data[lo_idx:hi_idx])
        y = torch.as_tensor(self.data[hi_idx])
        pad = self.block_size - x.shape[0]
        if pad > 0:
            # new_zeros keeps the dtype and device of the data.
            x = torch.cat([x.new_zeros(pad), x])
        return x, y

    def apply_transform(self):
        # No-op. NOTE(review): presumably kept for interface parity with other
        # pclib datasets (e.g. PreloadedDataset); confirm before removing.
        pass

# Encode the whole corpus once as a 1-D LongTensor of token ids on `device`.
# Split it contiguously (80% train / 10% val / 10% test) instead of shuffling:
# samples are context windows over running text, so shuffling the tokens would
# destroy the sequences the model is supposed to learn from.
encoded_text = torch.tensor(encode(text), dtype=torch.long, device=device)
split1 = int(len(encoded_text) * 0.8)
split2 = int(len(encoded_text) * 0.9)

train_dataset = CustomDataset(encoded_text[:split1])
val_dataset = CustomDataset(encoded_text[split1:split2])
test_dataset = CustomDataset(encoded_text[split2:])

In [8]:
INPUT_SHAPE = BLOCK_SIZE
# Token ids run 1..len(chars) (stoi starts at 1; 0 is left free as padding),
# so the classifier needs len(chars) + 1 outputs to index every target id.
NUM_CLASSES = len(chars) + 1
torch.manual_seed(42)

# Encode the actual context length in the name (and thus the log dir), so runs
# with different BLOCK_SIZE values are not mislabelled.
model_name = f'FCClUs-BS{BLOCK_SIZE}'
model = FCClassifierUs(
    in_features=INPUT_SHAPE,
    num_classes=NUM_CLASSES,
    hidden_sizes=[200, 200],
    bias=True,
    symmetric=True,
    precision_weighted=False,
    actv_fn=F.tanh,
    steps=100,
    gamma=0.34,
).to(device)

In [9]:
# Batch size used for both training phases below.
BATCH_SIZE = 256
log_dir = f'examples/shakespeare/logs/{model_name}'

# Two training phases run back to back on the same model and log_dir:
#   1. 10 epochs at lr=1e-3 with c_lr=0.0
#   2.  5 epochs at a 10x lower lr=1e-4 with c_lr=0.01
# NOTE(review): c_lr's meaning is defined by pclib's `train` (presumably a
# learning rate for a second parameter group); confirm there.
TRAIN_PHASES = [
    # (NUM_EPOCHS, lr, c_lr)
    (10, 0.001, 0.0),
    (5, 0.0001, 0.01),
]

for NUM_EPOCHS, lr, c_lr in TRAIN_PHASES:
    train(
        model,
        train_dataset,
        val_dataset,
        NUM_EPOCHS,
        lr=lr,
        c_lr=c_lr,
        batch_size=BATCH_SIZE,
        reg_coeff=0.02,
        optim='AdamW',
        save_best=False,
        log_dir=log_dir,
    )

                                                                                                             