In [24]:
from datasets import fuzzy_boolean_dataset
import torch.nn as nn
from models import simple_machine, cls_machine
import torch
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import r2_score

In [25]:
batch_size = 128

In [26]:
dataset = fuzzy_boolean_dataset.FuzzyBooleanDataset('data/fbd.npy')
train_len = int( 0.8 * len(dataset) ) 
val_len = len(dataset) - train_len
train_set, val_set = torch.utils.data.random_split(dataset, [train_len, val_len])
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, \
    shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, \
    shuffle=True, num_workers=2)

In [27]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = simple_machine.SimpleEncoder(32, 64, 2, 2, 4, 8)
        self.cls_net = cls_machine.ClsMachine(self.backbone, 30, 1)
        self.reg_head = nn.Linear(32, 1)
    
    def forward(self, x):
        temp = x 
        temp = self.cls_net(temp)
        temp = self.reg_head(temp)
        return temp

In [28]:
device = torch.device('cuda')
model = Model().to(device)

In [32]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [33]:
for epoch in range(6):
    running_loss = []
    pbar = tqdm(trainloader)
    for X, Y in pbar:
        X = X.to(device).unsqueeze(2).float()
        Y = Y.to(device).float()[:, :20]
        optimizer.zero_grad()
        pred = model(X).squeeze(2)[:, :20]
        loss = criterion(pred, Y)
        loss.backward()
        optimizer.step()
        pbar.set_description(f"Epoch {epoch}")
        running_loss.append(loss.item())
        pbar.set_postfix(loss = sum(running_loss) / len(running_loss))
    
    with torch.no_grad():
        pbar = tqdm(valloader)
        scores = []
        for X, Y in pbar:
            X = X.to(device).unsqueeze(2).float().detach()
            Y = Y.to(device).float()[:, :20]
            pred = model(X).squeeze(2)[:, :20]
            score = r2_score(pred.detach().cpu(), Y.detach().cpu())
            scores.append(score)
            pbar.set_postfix(score = sum(scores) / len(scores))

Epoch 0: 100%|██████████| 1022/1022 [01:01<00:00, 16.53it/s, loss=0.012] 
100%|██████████| 256/256 [00:06<00:00, 41.22it/s, score=0.733]
Epoch 1: 100%|██████████| 1022/1022 [01:01<00:00, 16.62it/s, loss=0.00148]
100%|██████████| 256/256 [00:06<00:00, 41.07it/s, score=0.863]
Epoch 2: 100%|██████████| 1022/1022 [01:01<00:00, 16.55it/s, loss=0.000788]
100%|██████████| 256/256 [00:06<00:00, 40.87it/s, score=0.911]
Epoch 3: 100%|██████████| 1022/1022 [01:01<00:00, 16.62it/s, loss=0.00055] 
100%|██████████| 256/256 [00:06<00:00, 41.13it/s, score=0.94] 
Epoch 4: 100%|██████████| 1022/1022 [01:01<00:00, 16.57it/s, loss=0.000419]
100%|██████████| 256/256 [00:06<00:00, 41.07it/s, score=0.947]
Epoch 5: 100%|██████████| 1022/1022 [01:01<00:00, 16.61it/s, loss=0.000338]
100%|██████████| 256/256 [00:06<00:00, 40.95it/s, score=0.966]


In [41]:
# backup = model
#model = backup
optimizer = torch.optim.Adam([model.cls_net.cls_embeddings], lr=1e-3)

In [43]:
for epoch in range(2):
    running_loss = []
    pbar = tqdm(trainloader)
    for X, Y in pbar:
        X = X.to(device).unsqueeze(2).float()
        Y = Y.to(device).float()[:, 20:]
        optimizer.zero_grad()
        pred = model(X).squeeze(2)[:, 20:]
        loss = criterion(pred, Y)
        loss.backward()
        optimizer.step()
        pbar.set_description(f"Epoch {epoch}")
        running_loss.append(loss.item())
        pbar.set_postfix(loss = sum(running_loss) / len(running_loss))
    
    with torch.no_grad():
        pbar = tqdm(valloader)
        scores = []
        for X, Y in pbar:
            X = X.to(device).unsqueeze(2).float().detach()
            Y = Y.to(device).float()[:, 20:]
            pred = model(X).squeeze(2)[:, 20:]
            score = r2_score(pred.detach().cpu(), Y.detach().cpu())
            scores.append(score)
            pbar.set_postfix(score = sum(scores) / len(scores))

Epoch 0: 100%|██████████| 1022/1022 [00:52<00:00, 19.62it/s, loss=0.00653]
100%|██████████| 256/256 [00:06<00:00, 40.96it/s, score=0.48] 
Epoch 1: 100%|██████████| 1022/1022 [00:52<00:00, 19.53it/s, loss=0.00616]
100%|██████████| 256/256 [00:06<00:00, 40.88it/s, score=0.545]
