In [1]:
import torch
import os
import sys

# Make the `src` folder (resolved relative to sys.path[0]) importable.
# Presumably this is where the `utils` and `data_loader` modules imported by
# later cells live -- TODO confirm.
sys.path.insert(1, os.path.join(sys.path[0], 'src'))

# Directory that holds the pre-processed artifacts for the bc-3 experiment.
processed_dir = '/data/pengmiao/PaCKD_0/processed'

# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# files that come from a trusted source.
train_loader, test_loader = (
    torch.load(os.path.join(processed_dir, split_file))
    for split_file in ("bc-3.train.pt", "bc-3.test.pt")
)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
from utils import select_tch

# Use GPU index 2 when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")

# Build the model registered under key 'r' and move it to the chosen device.
# NOTE(review): the meaning of 'r' is defined by utils.select_tch -- confirm.
model = select_tch('r').to(device)

In [3]:
import torch.nn.functional as F

# Shared sigmoid layer: squashes the model's raw outputs into (0, 1) before
# they are passed to F.binary_cross_entropy.
sigmoid = torch.nn.Sigmoid()


def train(ep, train_loader, model_save_path):
    """Run one training epoch and return the mean per-batch BCE loss.

    Operates on the notebook-level ``model`` and ``optimizer`` globals, which
    must be defined before this function is called.

    Args:
        ep: Epoch index. Accepted for call compatibility; not used.
        train_loader: Iterable yielding ``(data, target)`` batches.
        model_save_path: Accepted for call compatibility; not used.

    Returns:
        float: Sum of the per-batch mean BCE losses divided by the number of
        batches processed.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for data, target in train_loader:
        # NOTE(review): batches are fed to the model as-is; this assumes the
        # loader already yields tensors on the model's device -- confirm.
        optimizer.zero_grad()  # clear gradients once per batch
        output = sigmoid(model(data))
        loss = F.binary_cross_entropy(output, target, reduction='mean')
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1

    # Count batches instead of relying on len(train_loader) so an empty loader
    # gives a clear error rather than a ZeroDivisionError.
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return epoch_loss / num_batches


def test(test_loader):
    """Evaluate the notebook-level ``model`` and return the mean per-batch BCE loss.

    Puts ``model`` in eval mode and runs every batch under ``torch.no_grad()``.
    Relies on the notebook-level ``model``, ``sigmoid`` and ``F`` names.

    Args:
        test_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        float: Sum of the per-batch mean BCE losses divided by the number of
        batches processed.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for data, target in test_loader:
            # NOTE(review): assumes batches are already on the model's device -- confirm.
            output = sigmoid(model(data))
            total_loss += F.binary_cross_entropy(output, target, reduction='mean').item()
            num_batches += 1

    # The previous thresholded "correct" counter was computed but never used or
    # returned, so it has been removed; only the loss is reported.
    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")
    return total_loss / num_batches

In [4]:
import csv

def run_epoch(epochs, early_stop, loading, model_save_path, train_loader, test_loader, tsv_path, model, scheduler=None):
    """Train with early stopping, checkpoint the best model, and log losses to a TSV.

    Each epoch calls the notebook-level ``train`` and ``test`` functions. The
    model's ``state_dict`` is saved whenever the test loss is less than or
    equal to the best seen so far; otherwise the patience counter is
    decremented, and training stops when it reaches zero.

    NOTE(review): ``train`` and ``test`` operate on the notebook-level ``model``
    global, while this function loads/saves the ``model`` argument. Pass the
    same object for both, as the notebook does.

    Args:
        epochs: Maximum number of epochs to run.
        early_stop: Number of consecutive non-improving epochs tolerated before
            stopping.
        loading: If truthy, load weights from ``model_save_path`` before training.
        model_save_path: File path used to load and save the model ``state_dict``.
        train_loader: Batches passed to ``train``.
        test_loader: Batches passed to ``test``.
        tsv_path: Destination of the tab-separated per-epoch loss log.
        model: Model whose ``state_dict`` is loaded and saved.
        scheduler: Optional learning-rate scheduler; when given, ``step()`` is
            called once after every epoch. Defaults to ``None`` (no stepping),
            which matches the previous behaviour.
    """
    if loading:
        model.load_state_dict(torch.load(model_save_path))
        print("-------------Model Loaded------------")

    # Start from +inf so the first finite test loss always counts as the best;
    # a NaN first epoch can then no longer block every later checkpoint.
    best_loss = float('inf')
    curr_early_stop = early_stop

    metrics_data = []

    for epoch in range(epochs):
        train_loss = train(epoch, train_loader, model_save_path)
        test_loss = test(test_loader)
        print((f"Epoch: {epoch+1} - loss: {train_loss:.10f} - test_loss: {test_loss:.10f}"))

        # Record the epoch before the early-stop check so the epoch that
        # triggers the stop is still written to the TSV.
        metrics_data.append([epoch+1, train_loss, test_loss])

        if scheduler is not None:
            scheduler.step()

        if test_loss <= best_loss:
            torch.save(model.state_dict(), model_save_path)
            best_loss = test_loss
            print("-------- Save Best Model! --------")
            curr_early_stop = early_stop  # reset patience on improvement
        else:
            curr_early_stop -= 1
            print("Early Stop Left: {}".format(curr_early_stop))
        if curr_early_stop == 0:
            print("-------- Early Stop! --------")
            break

    # newline='' lets csv.writer control line endings itself (avoids blank
    # rows on platforms that translate '\n').
    with open(tsv_path, 'w', newline='') as file:
        writer = csv.writer(file, delimiter='\t')
        writer.writerow(['Epoch', 'Train Loss', 'Test Loss'])
        writer.writerows(metrics_data)

Baseline: teacher model R trained with no clustering and no knowledge distillation (KD), lr = 0.0005

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from data_loader import init_dataloader

# Same device choice as in the model cell: GPU index 2 if CUDA is available.
if torch.cuda.is_available():
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")

# Run configuration.
epochs = 50
early_stop = 15   # patience: consecutive non-improving epochs allowed
loading = False   # train from the current weights instead of loading a checkpoint
model_save_path = '/data/pengmiao/PaCKD_1/model/bc-3.teacher.lr.5.r.pth'
tsv_path = '/data/pengmiao/PaCKD_1/model/bc-3.teacher.lr.5.r.tsv'

# Adam optimizer plus a step scheduler (x0.1 every 20 scheduler steps).
# NOTE(review): run_epoch is not given `scheduler` below and nothing in this
# cell calls scheduler.step(), so the learning rate stays at 0.0005 unless it
# is stepped elsewhere.
optimizer = optim.Adam(model.parameters(), lr=0.0005)
scheduler = StepLR(optimizer, step_size=20, gamma=0.1)

# NOTE(review): presumably tells the data_loader module which GPU index to
# use ('2' matches cuda:2) -- confirm against data_loader.init_dataloader.
init_dataloader('2')

run_epoch(
    epochs=epochs,
    early_stop=early_stop,
    loading=loading,
    model_save_path=model_save_path,
    train_loader=train_loader,
    test_loader=test_loader,
    tsv_path=tsv_path,
    model=model,
)

Epoch: 1 - loss: 0.2245028114 - test_loss: 0.2189558178
-------- Save Best Model! --------
Epoch: 2 - loss: 0.2252602793 - test_loss: 0.2233707790
Early Stop Left: 14
Epoch: 3 - loss: 0.2224036560 - test_loss: 0.2943823651
Early Stop Left: 13
Epoch: 4 - loss: 0.2222977377 - test_loss: 0.2214530328
Early Stop Left: 12
Epoch: 5 - loss: 0.2219941840 - test_loss: 0.2131103262
-------- Save Best Model! --------
Epoch: 6 - loss: 0.2228332990 - test_loss: 0.2350406584
Early Stop Left: 14
Epoch: 8 - loss: 0.2237195330 - test_loss: 0.2279378600
Early Stop Left: 14
Epoch: 9 - loss: 0.2222437938 - test_loss: 0.2085595350
Early Stop Left: 13
Epoch: 10 - loss: 0.2227702274 - test_loss: 0.2058029802
-------- Save Best Model! --------
Epoch: 11 - loss: 0.2217662464 - test_loss: 0.2221409752
Early Stop Left: 14
Epoch: 12 - loss: 0.2220092879 - test_loss: 0.2046922247
-------- Save Best Model! --------
Epoch: 13 - loss: 0.2224930762 - test_loss: 0.2123219064
Early Stop Left: 14
Epoch: 14 - loss: 0.2213

In [None]:
# Load the saved bc-3 object from the processed-data directory.
# NOTE(review): torch.load unpickles arbitrary objects -- trusted files only.
test_df = torch.load(os.path.join(processed_dir, 'bc-3.df.pt'))

# NOTE(review): `run_val` is not defined or imported in any cell visible here,
# so this line depends on state from elsewhere -- confirm it works after
# Restart Kernel -> Run All. Its arguments are the test loader, the object
# loaded above, the file name 'bc-3.txt.xz', and the checkpoint path.
res = run_val(test_loader, test_df, 'bc-3.txt.xz', model_save_path) 