In [6]:
import torch
import os
import sys

# Directory that holds the pre-processed artifacts read by this notebook.
processed_dir = '/data/pengmiao/PaCKD_0/processed'

# Restore the serialized train / test objects from disk.
# NOTE(review): torch.load unpickles whatever the file contains — only point it at trusted files.
train_loader = torch.load(os.path.join(processed_dir, "bc-3.train.pt"))
test_loader = torch.load(os.path.join(processed_dir, "bc-3.test.pt"))

In [26]:
from utils import select_tch

# Resolve the compute device in this cell. `model.to(device)` below needs it, and
# relying on a definition from a later cell raises NameError on a fresh kernel.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# select_tch is a project helper from utils; 'm' presumably picks a model variant
# (confirm in utils). The returned module is then moved to `device`.
model = select_tch('m')
model = model.to(device)

In [34]:
# Module-level sigmoid shared by train() and test(): both apply it to the model's
# output before computing the binary cross-entropy loss.
sigmoid = torch.nn.Sigmoid()
import torch.nn.functional as F

def train(ep, train_loader, model_save_path):
    """Run one optimization pass over ``train_loader`` and return the mean batch loss.

    Uses the notebook-level globals ``model``, ``optimizer`` and ``sigmoid``.

    Args:
        ep: Epoch index. Accepted for interface compatibility; not used here.
        train_loader: Sized iterable yielding ``(data, target)`` batches.
        model_save_path: Accepted for interface compatibility; not used here.

    Returns:
        float: Sum of the per-batch BCE losses divided by ``len(train_loader)``.

    Raises:
        ZeroDivisionError: If ``len(train_loader)`` is 0.
    """
    epoch_loss = 0
    model.train()
    for data, target in train_loader:
        # Gradients are cleared exactly once per batch, right before the forward pass.
        optimizer.zero_grad()
        output = sigmoid(model(data))
        loss = F.binary_cross_entropy(output, target, reduction='mean')
        loss.backward()
        optimizer.step()
        # .item() detaches the scalar so no graph is kept alive across batches.
        epoch_loss += loss.item()
    epoch_loss /= len(train_loader)
    return epoch_loss


def test(test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = sigmoid(model(data))
            test_loss += F.binary_cross_entropy(output, target, reduction='mean').item()
            thresh=0.5
            output_bin=(output>=thresh)*1
            correct+=(output_bin&target.int()).sum()
        test_loss /=  len(test_loader)
        return test_loss

In [35]:
import csv

def run_epoch(epochs, early_stop, loading, model_load_path, model_save_path, train_loader, test_loader, tsv_path, model, scheduler=None):
    """Train with early stopping, checkpoint the best model and log per-epoch losses.

    Each epoch calls the notebook-level ``train`` and ``test`` functions. Whenever
    the test loss is less than or equal to the best seen so far, the model's
    state dict is saved to ``model_save_path`` and the patience counter is reset.
    Otherwise the counter is decremented, and training stops when it hits 0.

    Args:
        epochs: Maximum number of epochs to run.
        early_stop: Patience, i.e. the number of consecutive non-improving epochs
            tolerated before stopping.
        loading: If truthy, load a state dict from ``model_load_path`` first.
        model_load_path: Checkpoint path read when ``loading`` is truthy.
        model_save_path: Where the best state dict is written.
        train_loader: Passed through to ``train``.
        test_loader: Passed through to ``test``.
        tsv_path: Destination of the tab-separated metrics log.
        model: Module whose state dict is loaded and saved.
        scheduler: Optional object with a ``step()`` method (e.g. an LR scheduler),
            stepped once per epoch after training. ``None`` (default) keeps the
            previous behaviour of never stepping anything.

    Returns:
        None. Side effects: writes ``model_save_path`` and ``tsv_path``, prints progress.
    """
    if loading:
        model.load_state_dict(torch.load(model_load_path))
        print("-------------Model Loaded------------")

    # +inf makes the first epoch an improvement without special-casing epoch 0.
    best_loss = float('inf')
    curr_early_stop = early_stop

    metrics_data = []

    for epoch in range(epochs):

        train_loss = train(epoch, train_loader, model_save_path)
        test_loss = test(test_loader)
        print((f"Epoch: {epoch+1} - loss: {train_loss:.10f} - test_loss: {test_loss:.10f}"))

        # Record the epoch before any early-stop break so the final epoch is logged too.
        metrics_data.append([epoch+1, train_loss, test_loss])

        if scheduler is not None:
            scheduler.step()

        if test_loss <= best_loss:
            torch.save(model.state_dict(), model_save_path)
            best_loss = test_loss
            print("-------- Save Best Model! --------")
            curr_early_stop = early_stop
        else:
            curr_early_stop -= 1
            print("Early Stop Left: {}".format(curr_early_stop))
        if curr_early_stop == 0:
            print("-------- Early Stop! --------")
            break

    # newline='' is what the csv module expects; it prevents doubled line endings on Windows.
    with open(tsv_path, 'w', newline='') as file:
        writer = csv.writer(file, delimiter='\t')
        writer.writerow(['Epoch', 'Train Loss', 'Test Loss'])
        writer.writerows(metrics_data)

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from data_loader import init_dataloader

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Adam over every model parameter, plus a StepLR that multiplies the LR by 0.1
# every 20 scheduler steps.
# NOTE(review): this scheduler is not stepped during the run started below, so the
# learning rate appears to stay at its initial value — confirm whether decay is intended.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = StepLR(optimizer, step_size=20, gamma=0.1)

# Run configuration: epoch budget, early-stopping patience, checkpoint and log locations.
epochs = 50
early_stop = 15
loading = True  # start from the checkpoint at model_load_path
model_load_path = '/data/pengmiao/PaCKD_2/model/i/bc-3.teacher_2.m.pth'
model_save_path = '/data/pengmiao/PaCKD_1/model/bc-3.teacher.lr.1.m.pth'
tsv_path = '/data/pengmiao/PaCKD_1/model/bc-3.teacher.lr.1.m.tsv'

# Project helper from data_loader; its effect is not visible here (TODO confirm what '0' selects).
init_dataloader('0')

run_epoch(
    epochs,
    early_stop,
    loading,
    model_load_path,
    model_save_path,
    train_loader,
    test_loader,
    tsv_path,
    model,
)

-------------Model Loaded------------
Epoch: 1 - loss: 0.2342319620 - test_loss: 0.2266077912
-------- Save Best Model! --------
Epoch: 2 - loss: 0.2313984578 - test_loss: 0.2288997499
Early Stop Left: 14
Epoch: 3 - loss: 0.2303736973 - test_loss: 0.2261069042
-------- Save Best Model! --------
Epoch: 4 - loss: 0.2284807163 - test_loss: 0.2225959873
-------- Save Best Model! --------
Epoch: 5 - loss: 0.2279223187 - test_loss: 0.2268306982
Early Stop Left: 14
Epoch: 6 - loss: 0.2267487994 - test_loss: 0.2227692970
Early Stop Left: 13
Epoch: 7 - loss: 0.2264660713 - test_loss: 0.2295580302
Early Stop Left: 12
Epoch: 8 - loss: 0.2265785517 - test_loss: 0.2226403109
Early Stop Left: 11
Epoch: 9 - loss: 0.2256653242 - test_loss: 0.2270950673
Early Stop Left: 10
Epoch: 10 - loss: 0.2261850593 - test_loss: 0.2312971368
Early Stop Left: 9
Epoch: 11 - loss: 0.2255497147 - test_loss: 0.2297412506
Early Stop Left: 8
Epoch: 12 - loss: 0.2270941808 - test_loss: 0.2238346085
Early Stop Left: 7
Epoch

In [None]:
# Load the serialized object that accompanies the test split.
# NOTE(review): torch.load unpickles whatever the file contains — only use trusted files.
test_df = torch.load(
    os.path.join(processed_dir, 'bc-3.df.pt')
)

# NOTE(review): run_val is not defined or imported in the visible cells — confirm where it comes from.
res = run_val(
    test_loader,
    test_df,
    'bc-3.txt.xz',
    model_save_path,
)