In [6]:
import torch

# 1 logger

In [13]:
import logging
from  tqdm import tqdm

In [14]:
def get_logger(filename, verbosity=1, name=None):
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger

# 2 evaluate utils

In [15]:
def accuracy(y_pred,y):
    y_pred = torch.argmax(y_pred,-1)
    return sum(y_pred == y)

In [16]:
class Accumulator:
    def __init__(self,n):
        self.data = [0.0]*n
    
    def add(self,*args):
        self.data = [a+float(b) for a,b in zip(self.data,args)]
        
    def reset(self):
        self.data = [0.0]*n
    
    def __getitem__(self,idx):
        return self.data[idx]

In [17]:
def eval_acc(net,data_iter):
    if isinstance(net,torch.nn.Module):
        net.eval()
    metric = Accumulator(2)
    with torch.no_grad():
        for x,y in data_iter:
            pred_y = net(y)
            metric.add(accuracy(y_pred,y),y.numel())
    return metric[0]/metric[1]

# 3 train utils

In [12]:
import transformers
from transformers import set_seed
set_seed(42)

In [1]:
def train_epoch(data_iter,net,loss,optimize):
    if isinstance(net,torch.nn.Module):
        net.train()
    metric = Accumulator(3)
    for x,y in data_iter:
        y_pred = net(x)
        loss_tmp = loss(y_pred,y)
        metric.add(loss_tmp*len(y),accuracy(y_pred,y),y.numel())
        ##optimize
        optimize.zero_grad()
        loss_tmp.backward()
        optimize.step()
    return metric[0]/metric[2],metric[1]/metric[2]

In [19]:
def train(train_iter,test_iter,net,loss,optimize,num_epochs,logger):
    for epoch in tqdm(range(num_epochs)):
        train_metrics = train_epoch(train_iter,net,loss,optimize)
        test_acc = evaluate_acc(net,test_loader)
        print('epoch:%d\ttrain_loss:%f\ttrain_acc:%f\ttest_acc:%f'%(epoch,train_metrics[0],train_metrics[1],test_acc))
        logger.info('epoch:%d\ttrain_loss:%f\ttrain_acc:%f\ttest_acc:%f'%(epoch,train_metrics[0],train_metrics[1],test_acc))
    train_loss,train_acc = train_metrics

# 4 optimize set

In [13]:

model = torch.nn.Sequential(torch.nn.Linear(in_features=100,out_features=2))

##optimize
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
{
    "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
    "weight_decay": 0.0,
},
]
optimizer = transformers.AdamW( optimizer_grouped_parameters, 
                               lr=3e-5,betas=(0.9, 0.999),eps=1e-08)

##learnign_rate adjust
num_training_step=100
lr_scheduler = transformers.optimization.get_scheduler(name='linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_step
)

##grad clip
torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
print(optimizer)
print('Debug optimi',optimizer.state_dict())
print('Debug!!!sche',lr_scheduler.state_dict())


AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-08
    initial_lr: 3e-05
    lr: 3e-05
    weight_decay: 0.0

Parameter Group 1
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-08
    initial_lr: 3e-05
    lr: 3e-05
    weight_decay: 0.0
)
Debug optimi {'state': {}, 'param_groups': [{'weight_decay': 0.0, 'lr': 3e-05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'correct_bias': True, 'initial_lr': 3e-05, 'params': [0]}, {'weight_decay': 0.0, 'lr': 3e-05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'correct_bias': True, 'initial_lr': 3e-05, 'params': [1]}]}
Debug!!!sche {'base_lrs': [3e-05, 3e-05], 'last_epoch': 0, '_step_count': 1, 'verbose': False, '_get_lr_called_within_step': False, '_last_lr': [3e-05, 3e-05], 'lr_lambdas': [None, None]}
