# Danbooru tagger training

In [1]:
from tagger import *

In [2]:
# Notebook switches: `local` selects the small debug configuration,
# `reload` resumes from the checkpoints in state-dicts/.
local, reload = True, False

In [3]:
# Local debug runs use a shorter base patience than full runs.
init_patience = 1 if local else 2

# ReduceLROnPlateau settings per LR phase, keyed by the phase's starting LR.
# The first phase's min_lr (.02) is the starting LR of the second phase.
sched_params = {
    .1: dict(threshold=.01, patience=init_patience, min_lr=.02, factor=.2),
    .02: dict(threshold=1e-3, patience=init_patience * 5, min_lr=1e-3,
              factor=.5),
}

## Data loading

In [4]:
def get_data(train_ds, val_ds, bs):
    """Wrap the training and validation datasets in DataLoaders.

    Training batches have size `bs` and are shuffled. Validation batches are
    not shuffled and have size 2*bs, or, when the notebook-level `local`
    flag is set, the whole validation set forms a single batch.

    Returns:
        (train_loader, val_loader)
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
    val_batch_size = len(val_ds) if local else 2 * bs
    val_loader = DataLoader(val_ds, batch_size=val_batch_size)
    return train_loader, val_loader

In [5]:
# Image root; the following cells read its 'train' and 'val' subfolders.
# NOTE(review): `data_path` is presumably provided by `from tagger import *`
# as a pathlib.Path; confirm in tagger.
imgs_path = data_path / 'less-images'

In [6]:
# Training split: image ids are the file stems in the train folder, and only
# the label rows for those ids are kept.
train_dir = imgs_path / 'train'
train_ids = [int(img_file.stem) for img_file in train_dir.glob('*')]
train_labels = all_labels[all_labels.id.isin(train_ids)]
train_ds = DanbooruDataset(img_dir=train_dir, label_data=train_labels)

In [7]:
# Validation split, built the same way as the training split.
val_dir = imgs_path / 'val'
val_ids = [int(img_file.stem) for img_file in val_dir.glob('*')]
val_labels = all_labels[all_labels.id.isin(val_ids)]
val_ds = DanbooruDataset(img_dir=val_dir, label_data=val_labels)

In [8]:
if not local:
    # Full run on the complete splits.
    bs = 128
    train_len, val_len = len(train_ds), len(val_ds)
    train_dl, val_dl = get_data(train_ds, val_ds, bs)
else:
    # Debug run on small random subsets (presumably 10 train / 2 val items;
    # see get_random_sample in tagger).
    bs = 2
    train_samp = get_random_sample(train_ds, 10)
    val_samp = get_random_sample(val_ds, 2)
    train_len = len(train_samp)
    val_len = len(val_samp)
    train_dl, val_dl = get_data(train_samp, val_samp, bs)

## Definitions

### model and optimizer

In [9]:
def load_model(finetune=False):
    """Build a Tagger on `dev`, optionally restoring saved weights.

    Args:
        finetune: when True the parameters of `model.base` stay trainable;
            otherwise they are frozen (requires_grad=False).

    Returns:
        The Tagger instance. When the notebook-level `reload` flag is set,
        its weights are loaded from state-dicts/model_state_dict.pt, falling
        back to state-dicts/backup_model_state_dict.pt if that load raises
        RuntimeError.
    """
    model = Tagger().to(dev)
    for param in model.base.parameters():
        param.requires_grad = finetune

    if reload:
        fp = Path() / 'state-dicts'
        fn = 'model_state_dict.pt'
        try:
            model.load_state_dict(torch.load(fp / fn, map_location=dev))
        except RuntimeError:
            # The backup sits next to the primary file in state-dicts/
            # (that is where save_model writes it), not in the working dir.
            model.load_state_dict(torch.load(
                fp / f'backup_{fn}', map_location=dev))

    return model

In [10]:
def set_optimizer(model):
    """Create the AdamW optimizer over `model`'s trainable parameters.

    Only parameters with requires_grad=True are passed to AdamW, so frozen
    weights are skipped. The initial lr of .1 matches the first key of
    `sched_params`, which Scheduler uses to pick its phase. When the
    notebook-level `reload` flag is set, the optimizer state is restored from
    state-dicts/opt_state_dict.pt, falling back to the backup copy if that
    load raises RuntimeError.

    Returns:
        The optimizer.
    """
    optimizer = optim.AdamW(lr=.1, params=filter(
        lambda p: p.requires_grad, model.parameters()))

    if reload:
        fp = Path() / 'state-dicts'
        fn = 'opt_state_dict.pt'
        try:
            # This file holds an *optimizer* state dict; it must be loaded
            # into the optimizer, not into the model.
            optimizer.load_state_dict(torch.load(fp / fn, map_location=dev))
        except RuntimeError:
            optimizer.load_state_dict(torch.load(
                fp / f'backup_{fn}', map_location=dev))

    return optimizer

In [11]:
def save_model(model, opt, save_path=Path()/'state-dicts'):
    """Checkpoint the model and optimizer state dicts, plus backup copies.

    Writes model_state_dict.pt and opt_state_dict.pt into `save_path`, then
    copies each one to backup_<name> in the same directory. The directory is
    created if it does not exist yet.

    Args:
        model: module whose state_dict() is saved.
        opt: optimizer whose state_dict() is saved.
        save_path: target directory (Path or str).
    """
    save_path = Path(save_path)
    names = [f'{n}_state_dict.pt' for n in ('model', 'opt')]
    paths = [save_path / n for n in names]
    backups = [save_path / f'backup_{n}' for n in names]

    save_path.mkdir(parents=True, exist_ok=True)
    # Write both primaries first and only then refresh the backups. If a
    # torch.save is interrupted, the previous backup is still intact; if a
    # copy is interrupted, the fresh primary is.
    torch.save(model.state_dict(), paths[0])
    torch.save(opt.state_dict(), paths[1])
    for src, dst in zip(paths, backups):
        shutil.copy(src, dst)

### LR scheduler and early stopper

In [12]:
class Scheduler(optim.lr_scheduler.ReduceLROnPlateau):
    """ReduceLROnPlateau configured from `sched_params` for the current phase.

    The phase is chosen from the optimizer's current learning rate, rounded
    to 5 decimals. `last_lr` tracks the rounded LR after each step, and the
    scheduler writes its state to disk after every step.
    """

    def __init__(self, opt):
        self.last_lr = round(opt.param_groups[0]['lr'], 5)
        super().__init__(
            opt, verbose=True, **self._phase_params(self.last_lr))
        self.min_lr = self.min_lrs[0]

    @staticmethod
    def _phase_params(lr):
        """Return the sched_params entry for the phase that `lr` belongs to.

        An exact key match wins. Otherwise, e.g. for an optimizer restored in
        the middle of a decay (lr=0.01), use the phase with the smallest
        starting LR that is still >= lr. An lr above every key falls back to
        the first (largest) phase.
        """
        if lr in sched_params:
            return sched_params[lr]
        covering = [key for key in sched_params if key >= lr]
        key = min(covering) if covering else max(sched_params)
        return sched_params[key]

    def load(self, fp=Path()/'state-dicts'):
        """Restore scheduler state from `fp`, falling back to the backup."""
        fn = 'sched_state_dict.pt'
        try:
            super().load_state_dict(torch.load(
                fp / fn, map_location=dev))
        except RuntimeError:
            super().load_state_dict(torch.load(
                fp / f'backup_{fn}', map_location=dev))

    def save(self, fp=Path()/'state-dicts'):
        """Write the scheduler state (and an identical backup) into `fp`."""
        fn = 'sched_state_dict.pt'
        fp.mkdir(parents=True, exist_ok=True)
        sd = self.state_dict()
        torch.save(sd, fp / fn)
        torch.save(sd, fp / f'backup_{fn}')

    def step(self, val_loss):
        """Step on `val_loss`, record the rounded LR, and checkpoint."""
        super().step(val_loss)
        self.last_lr = round(self._last_lr[0], 5)
        self.save()

    def compare(self):
        """True once the LR has decayed to this phase's min_lr."""
        return self.last_lr == self.min_lr

In [13]:
class EarlyStopper():
    """Flag training for stopping once validation loss stops improving.

    Args:
        patience: number of consecutive non-improving steps tolerated before
            `early_stop` is set.
        min_delta: minimum decrease of the validation loss that counts as
            an improvement.
    """

    def __init__(self, patience=init_patience*10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0          # consecutive steps without improvement
        self.best_loss = None     # lowest validation loss seen so far
        self.early_stop = False   # becomes True once counter >= patience

    def step(self, val_loss):
        """Update the stopper with this epoch's validation loss."""
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss >= self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            print('INFO: stopper counter reset')
        else:
            self.counter += 1
            # Integer halfway mark, so odd patience values still get the
            # notice (counter can never equal e.g. 1.5).
            if self.counter == self.patience // 2:
                print('INFO: stopper counter halfway through')
            self.early_stop = self.counter >= self.patience

In [14]:
def check_sched(sched, stopper):
    """Move to the next LR phase, or hand control over to early stopping.

    Called once `sched` has reached its phase's min_lr. If its LR is still
    at or above the last phase's starting LR, a fresh Scheduler is built for
    the next phase. Otherwise LR scheduling is finished: an EarlyStopper is
    created and `sched.last_lr` is set to 0 so `sched.compare()` cannot
    fire again.

    Returns:
        (sched, stopper): the possibly replaced scheduler and stopper.
    """
    if sched.last_lr >= list(sched_params)[-1]:
        # Rebuild on the scheduler's own optimizer rather than a
        # notebook-level `opt` global, so any optimizer passed to fit() works.
        sched = Scheduler(sched.optimizer)
        print('INFO: scheduler refreshed')
    else:
        stopper = EarlyStopper()
        sched.last_lr = 0
        print('INFO: LR scheduling ended')
    return sched, stopper

### training loop

In [15]:
def calc_batch_loss(xb, yb, model, opt=None, loss_func=nn.MSELoss()):
    """Return the scalar loss of `model` on one batch, training on it if asked.

    With `opt` given, the loss is back-propagated, the optimizer steps, and
    gradients are cleared afterwards. With `opt=None` this is a forward pass
    only.
    """
    batch_loss = loss_func(model(xb), yb)
    if opt is None:
        return batch_loss.item()

    batch_loss.backward()
    opt.step()
    opt.zero_grad()
    return batch_loss.item()

In [16]:
def fit(model, opt, sched, train_dl, val_dl, res):
    """Train until the LR schedule is exhausted and early stopping fires.

    Each epoch runs in five steps:
      1. One training pass over `train_dl`, checkpointing with save_model
         after every batch.
      2. One no-grad pass over `val_dl`.
      3. A log line.
      4. Appending the epoch's losses to `res`.
      5. Writing `res` to losses.csv.

    The validation loss first drives `sched`. Once every LR phase is done
    (see check_sched), it drives an EarlyStopper instead. Training stops in
    the same epoch in which that stopper fires.

    Args:
        model: network to train; batches are moved to `dev` to match it.
        opt: optimizer over `model`'s trainable parameters.
        sched: Scheduler for the current LR phase.
        train_dl, val_dl: DataLoaders yielding (xb, yb) batches.
        res: DataFrame of earlier per-epoch losses (empty on a fresh run).

    Returns:
        The updated loss history (columns 'train' and 'val').
    """
    # Per-sample averages need the dataset sizes; take them from the loaders
    # instead of notebook globals so any pair of loaders can be passed in.
    train_len = len(train_dl.dataset)
    val_len = len(val_dl.dataset)
    stopper = None
    while True:

        model.train()
        running_loss = 0
        for xb, yb in train_dl:
            # Tensor.to() returns a new tensor; it does not move in place.
            xb, yb = xb.to(dev), yb.to(dev)
            batch_loss = calc_batch_loss(xb, yb, model, opt)
            running_loss += batch_loss * len(xb)
            save_model(model, opt)
        train_loss = running_loss / train_len

        model.eval()
        with torch.no_grad():
            val_loss = np.sum([
                calc_batch_loss(xb.to(dev), yb.to(dev), model) * len(xb)
                for xb, yb in val_dl
            ]) / val_len

        print(f'epoch: {len(res)}', end=' | ')
        print(f'train MSE: {train_loss:.4e}', end=' | ')
        print(f'val MSE: {val_loss:.4e}')
        res = pd.concat([res, pd.DataFrame({
            'train': [train_loss], 'val': [val_loss]})], ignore_index=True)
        res.to_csv('losses.csv', index=False)

        if stopper is None:
            sched.step(val_loss)
            if sched.compare():
                sched, stopper = check_sched(sched, stopper)
        else:
            stopper.step(val_loss)
            # Check right away so no extra epoch is trained once the
            # stopper has fired.
            if stopper.early_stop:
                print('INFO: training stopped')
                break

    return res

## run

In [17]:
# Build (or restore) the model, optimizer and scheduler, plus the loss
# history that fit() appends to.
model = load_model()
opt = set_optimizer(model)
sched = Scheduler(opt)

res = pd.read_csv('losses.csv') if reload else pd.DataFrame()
if reload:
    sched.load()

In [18]:
fit(model, opt, sched, train_dl, val_dl, res=res)

epoch: 0 | train MSE: 2.8850e-01 | val MSE: 3.5063e-01
epoch: 1 | train MSE: 2.4237e-01 | val MSE: 2.0487e-01
epoch: 2 | train MSE: 1.6477e-01 | val MSE: 1.8181e-01
epoch: 3 | train MSE: 1.5408e-01 | val MSE: 1.6484e-01
epoch: 4 | train MSE: 1.3909e-01 | val MSE: 1.6485e-01
epoch: 5 | train MSE: 1.2874e-01 | val MSE: 1.4270e-01
epoch: 6 | train MSE: 1.3182e-01 | val MSE: 1.3089e-01
epoch: 7 | train MSE: 1.3224e-01 | val MSE: 1.3120e-01
epoch: 8 | train MSE: 1.2726e-01 | val MSE: 1.5270e-01
Epoch 00009: reducing learning rate of group 0 to 2.0000e-02.
INFO: scheduler refreshed
epoch: 9 | train MSE: 1.3367e-01 | val MSE: 1.5396e-01
epoch: 10 | train MSE: 1.3148e-01 | val MSE: 1.4872e-01
epoch: 11 | train MSE: 1.3443e-01 | val MSE: 1.5519e-01
epoch: 12 | train MSE: 1.1427e-01 | val MSE: 1.5510e-01
epoch: 13 | train MSE: 1.3167e-01 | val MSE: 1.5624e-01
epoch: 14 | train MSE: 1.1462e-01 | val MSE: 1.6599e-01
epoch: 15 | train MSE: 1.3348e-01 | val MSE: 1.5560e-01
epoch: 16 | train MSE: 1.2