In [None]:
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [None]:
# Preferred GPU index for this notebook. If the machine exposes fewer GPUs
# than that, fall back to the first one instead of failing later with an
# invalid device ordinal when tensors are moved to DEVICE.
PREFERRED_GPU_INDEX = 3
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    DEVICE = f'cuda:{gpu_index}'
else:
    DEVICE = 'cpu'
print(f'Using {DEVICE}')

In [None]:
class BadNet(nn.Module):
    """Small MLP mapping a ``token_dim``-wide input to an ``output_dim``-wide output.

    The network is four ``nn.Linear`` layers with a fixed hidden width of 128
    and a ReLU after each of the first three layers.

    Args:
        token_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        normalize: When true, the output is L2-normalised along its last dimension.
    """

    def __init__(self, token_dim=77, output_dim=1024, normalize=False):
        super().__init__()
        hidden = 128
        widths = [token_dim, hidden, hidden, hidden]

        # Layers are created in forward order so seeded initialisation is reproducible.
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden, output_dim))

        self.seq = nn.Sequential(*layers)
        self.normalize = normalize

    def forward(self, x):
        """Apply the MLP to ``x`` and optionally L2-normalise the result."""
        out = self.seq(x)
        return nn.functional.normalize(out, dim=-1) if self.normalize else out

In [None]:
class BabySet(data.Dataset):
    """Dataset pairing token entries with feature entries loaded from disk.

    Item ``index`` is the tuple ``(tokens[index], features[index])``. Both
    loaded objects must have the same size along dimension 0.
    """

    def __init__(self, token_file, feature_file):
        """Load both files and check they hold the same number of samples.

        Args:
            token_file: File passed to ``torch.load`` for the tokens.
            feature_file: File passed to ``torch.load`` for the features.

        Raises:
            AssertionError: If the two loaded objects differ in their first dimension.
        """
        # NOTE(security): torch.load unpickles its input; only load trusted files.
        # map_location='cpu' keeps the samples on the host regardless of the
        # device they were saved from, so loading does not require that device
        # and the data is safe to hand to DataLoader worker processes.
        self.features = torch.load(feature_file, map_location='cpu')
        self.tokens = torch.load(token_file, map_location='cpu')
        assert self.features.shape[0] == self.tokens.shape[0], (
            f'features has {self.features.shape[0]} samples '
            f'but tokens has {self.tokens.shape[0]}'
        )

    def __len__(self):
        """Return the number of samples (size of dimension 0)."""
        return self.features.shape[0]

    def __getitem__(self, index):
        """Return the ``(tokens, features)`` pair stored at ``index``."""
        return self.tokens[index], self.features[index]

babyset = BabySet( './data/tokens.pt', './data/features.pt')
print(len(babyset))
print(babyset[0])

# Seed the split so that the train/validation/test membership is identical
# after every kernel restart; an unseeded split silently changes which
# samples end up in the held-out sets from run to run.
SPLIT_SEED = 0
trainset, validset, testset = data.random_split(
    babyset,
    [0.8, 0.1, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
BATCH_SIZE = 128
NUM_WORKERS = 8

# Only the training loader shuffles. Evaluation metrics are averaged per batch,
# so shuffling the validation/test data would change batch composition (and
# therefore the reported numbers) from one pass to the next.
trainloader = data.DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
validloader = data.DataLoader(validset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
testloader = data.DataLoader(testset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [None]:
def run_epoch(model, dataloader, optimizer, criterion, metric, train, verbose=False):
    """Run one pass over ``dataloader``, updating ``model`` when ``train`` is true.

    Args:
        model: Module called on each input batch.
        dataloader: Iterable of ``(X, y)`` batches; must support ``len()``.
        optimizer: Optimizer stepped once per batch when ``train`` is true;
            not used otherwise.
        criterion: Callable ``criterion(pred, y)`` returning the loss to backpropagate.
        metric: Callable ``metric(pred, y)``; its result is reduced with ``.mean()``.
        train: If true, put the model in training mode, enable gradients and
            update the parameters; otherwise use eval mode with gradients disabled.
        verbose: If true, show the current batch loss and metric on the progress bar.

    Returns:
        Tuple ``(losses, accs)`` of 1-D CPU tensors holding one value per batch.
    """
    if train:
        model.train()
    else:
        model.eval()
    with torch.set_grad_enabled(train):
        progress = tqdm(dataloader)
        num_batches = len(progress)
        losses = torch.zeros(num_batches)
        accs = torch.zeros(num_batches)
        for i, (X, y) in enumerate(progress):
            X, y = X.to(DEVICE), y.to(DEVICE)
            if train:
                # Clear gradients left over from the previous step; without this
                # every backward() call adds onto the accumulated gradients.
                optimizer.zero_grad()
            pred = model(X)
            loss = criterion(pred, y)
            if train:
                loss.backward()
                optimizer.step()
            acc = metric(pred, y).mean()
            if verbose:
                progress.set_description(f'Loss = {loss:.4f}, Accuracy = {acc * 100:02.2f}%')
            # .item() copies the scalar to a Python float, detached from the graph.
            losses[i] = loss.item()
            accs[i] = acc.item()
    return losses, accs

In [None]:
# Model plus the objects used to optimise and score it.
model = BadNet(normalize=False).to(DEVICE)
criterion = nn.MSELoss()
metric = nn.CosineSimilarity()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

EPOCHS = 10

# Per-epoch history; every entry is the mean of that epoch's per-batch values.
train_losses, train_accs = [], []
valid_losses, valid_accs = [], []

for epoch in range(EPOCHS):
    print(f'===== EPOCH {epoch+1:02} =====')

    # Training pass: parameters are updated.
    print('Training...')
    epoch_train_losses, epoch_train_accs = run_epoch(
        model, trainloader, optimizer, criterion, metric, train=True
    )
    train_losses.append(epoch_train_losses.mean())
    train_accs.append(epoch_train_accs.mean())
    print(f'Epoch Train Loss = {train_losses[-1]:.4f}, Epoch Train Accuracy = {train_accs[-1] * 100:02.2f}%')

    # Validation pass: evaluation only.
    print('Validating...')
    epoch_valid_losses, epoch_valid_accs = run_epoch(
        model, validloader, optimizer, criterion, metric, train=False
    )
    valid_losses.append(epoch_valid_losses.mean())
    valid_accs.append(epoch_valid_accs.mean())
    print(f'Epoch Validation Loss = {valid_losses[-1]:.4f}, Epoch Validation Accuracy = {valid_accs[-1] * 100:02.2f}%')

# Final evaluation on the held-out test split.
print(f'===== TESTING =====')
test_losses, test_accs = run_epoch(
    model, testloader, optimizer, criterion, metric, train=False
)
print(f'Test Loss = {test_losses.mean():.4f}, Test Accuracy = {test_accs.mean() * 100:02.2f}%')