Merge pull request #35 from minggnim/bugfix-metrics
Bugfix metrics
minggnim committed Jun 20, 2023
2 parents 2c8e56f + c6b4a5d commit 3f4f732
Showing 7 changed files with 183 additions and 168 deletions.
273 changes: 139 additions & 134 deletions notebooks/01_a_classification_model_training_example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/02_b_multitask_model_training_example.ipynb
@@ -189,7 +189,7 @@
"tokenizer_dir = output_dir / 'tokenizer'\n",
"model_file = output_dir / 'mtl.bin'\n",
"\n",
"model,_,_ = trainer.load_checkpoint('./drive/MyDrive/chatbot/chkpt/chkpt9.pt', model)\n",
"model,_,_ = trainer.load_checkpoint('../chkpt/chkpt9.pt', model)\n",
"pretrained_tokenizer.save_pretrained(tokenizer_dir)\n",
"AutoModelForMTL.save_model(model, model_file)"
]
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = nlp-models
version = 2.3.1
version = 2.3.2
author= Ming Gao
author_email = ming_gao@outlook.com
url = https://github.com/minggnim/nlp-models
10 changes: 8 additions & 2 deletions src/bert_classifier/data.py
@@ -3,14 +3,20 @@
'''
import torch
from torch.utils.data import Dataset
from .bert import bert_encoder


class CustomDataset(Dataset):
'''
Class to construct torch Dataset from dataframe
'''
def __init__(self, dataframe, data_field, label_field, tokenizer, max_len, multi_label):
def __init__(self,
dataframe,
data_field,
label_field,
tokenizer,
max_len,
multi_label=False
):
self.max_len = max_len
self.data = dataframe
self.tokenizer = tokenizer
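
With multi_label now a keyword argument defaulting to False, single-label callers of CustomDataset no longer need to pass the flag explicitly. A minimal construction sketch, assuming a pandas DataFrame with text/label columns and a Hugging Face tokenizer (the column names and model choice are illustrative, not taken from the repo; the import path is inferred from the src/ layout):

import pandas as pd
from transformers import AutoTokenizer
from bert_classifier.data import CustomDataset  # import path assumed

df = pd.DataFrame({"text": ["good product", "poor service"], "label": [1, 0]})
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# multi_label defaults to False, so a plain single-label dataset needs only these arguments
dataset = CustomDataset(df, data_field="text", label_field="label",
                        tokenizer=tokenizer, max_len=64)
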
6 changes: 3 additions & 3 deletions src/bert_classifier/train.py
@@ -23,12 +23,12 @@ def loss_fn(outputs, targets):
return torch.nn.CrossEntropyLoss()(outputs, targets)


def optimizer(model, learning_rate):
optimizer = torch.optim.AdamW(
def optimizer_obj(model, learning_rate):
opt = torch.optim.AdamW(
params=model.parameters(),
lr=learning_rate
)
return optimizer
return opt


def custom_trainer(model, optimizer, train_dataloader, test_dataloader, epochs=5, device='cpu'):
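
The factory is renamed from optimizer to optimizer_obj, presumably so it no longer shares a name with the optimizer objects the module passes around (custom_trainer below takes an optimizer argument). A short usage sketch with a stand-in model, since any torch.nn.Module works; the import path is inferred from the src/ layout:

import torch
from bert_classifier.train import optimizer_obj  # import path assumed

model = torch.nn.Linear(10, 2)                  # stand-in model for illustration
opt = optimizer_obj(model, learning_rate=2e-5)  # AdamW over model.parameters()
print(type(opt).__name__)                       # AdamW
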
4 changes: 2 additions & 2 deletions src/multi_task_model/metrics.py
@@ -11,8 +11,8 @@
def hamming_distance(
outputs,
targets,
multi_label: bool = False,
num_labels: Optional[int] = None,
num_labels: int,
multi_label: bool = False,
device: torch.device = torch.device('cpu'),
average: Literal['micro', 'macro', 'weighted', 'none'] = 'macro'
):
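
This is the change the PR title points at: num_labels becomes a required parameter and moves ahead of the optional multi_label flag, so positional call sites (the trainer below) have to pass arguments in the new order. A hedged usage sketch, assuming the function takes raw logits and integer targets the way the trainer passes them and returns a tensor (the trainer calls .item() on the result); the import path is inferred from the src/ layout:

import torch
from multi_task_model.metrics import hamming_distance  # import path assumed

outputs = torch.randn(8, 5)           # logits for 8 samples, 5 classes
targets = torch.randint(0, 5, (8,))   # single-label integer targets

# new order: num_labels (required) now precedes multi_label (defaults to False);
# keyword arguments keep the call robust to any further reordering
score = hamming_distance(outputs, targets, num_labels=5, multi_label=False)
print(score.item())
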
54 changes: 29 additions & 25 deletions src/multi_task_model/trainer.py
@@ -15,7 +15,7 @@
class Configs:
epochs: int = 1
optimizer_class = torch.optim.AdamW
optimizer_params: Dict[str, float] = field(default_factory = lambda: ({"lr": 2e-5}))
optimizer_params: Dict[str, float] = field(default_factory=lambda: ({"lr": 2e-5}))
weight_decay: float = 0.01
scheduler: str = 'WarmupLinear'
warmup_steps: int = 10000
@@ -25,14 +25,14 @@ class Configs:


class Trainer:
def __init__(self,
def __init__(self,
model: torch.nn.Module,
train_dataloader: DataLoader,
test_dataloader: DataLoader,
configs: Configs,
metrics = accuracy,
metrics=accuracy,
device: torch.device = torch.device('cpu'),
chkpt_dir = Path('../chkpt')
chkpt_dir=Path('../chkpt')
):
self.model = model
self.train_dataloader = train_dataloader
@@ -54,19 +54,19 @@ def set_params(self):
param.requires_grad = False

self.optimizer = self.get_optimizer(
list(self.model.named_parameters()),
self.configs.optimizer_class,
self.configs.optimizer_params,
list(self.model.named_parameters()),
self.configs.optimizer_class,
self.configs.optimizer_params,
self.configs.weight_decay)

self.scheduler = self.get_scheduler(
self.optimizer,
self.configs.scheduler,
self.configs.warmup_steps,
len(self.train_dataloader)*self.configs.epochs)
self.optimizer,
self.configs.scheduler,
self.configs.warmup_steps,
len(self.train_dataloader) * self.configs.epochs)

def train(self):
epochs = tqdm(range(1, self.configs.epochs+1), leave = True, desc="Training...")
epochs = tqdm(range(1, self.configs.epochs + 1), leave=True, desc="Training...")
for epoch in epochs:
self.model.train()
epochs.set_description(f"EPOCH {epoch} / {self.configs.epochs} | training...")
@@ -105,10 +105,12 @@ def train_one_epoch(self, epoch):

total_train_loss += loss.item()
total_train_acc += self.metrics(
outputs[0], labels,
self.configs.multi_label,
self.configs.num_labels,
self.device).item()
outputs[0],
labels,
self.configs.num_labels,
self.configs.multi_label,
self.device
).item()

batches.set_description(f"Train Loss Step: {loss.item():.2f}")

@@ -141,10 +143,12 @@ def validate_one_epoch(self, epoch):

avg_val_loss = val_loss.mean().item()
avg_val_acc = self.metrics(
val_outputs, val_targets,
self.configs.multi_label,
self.configs.num_labels,
self.device).item()
val_outputs,
val_targets,
self.configs.num_labels,
self.configs.multi_label,
self.device
).item()

self.logger(epoch, avg_val_acc, avg_val_loss, 'test')

@@ -211,14 +215,14 @@ def save_checkpoint(self, epoch):

def schedule_cold_start(self):
self.scheduler = self.get_scheduler(
self.optimizer,
self.configs.scheduler,
0,
len(self.train_dataloader)*self.configs.epochs)
self.optimizer,
self.configs.scheduler,
0,
len(self.train_dataloader) * self.configs.epochs)

def print_per_epoch(self, epoch):
print(f"\n\n{'-'*30}EPOCH {epoch}/{self.configs.epochs}{'-'*30}")
epoch -= 1
epoch -= 1
train_loss = self.train_logs[epoch][f'epoch_{epoch}']['loss']
train_acc = self.train_logs[epoch][f'epoch_{epoch}']['accuracy']
val_loss = self.val_logs[epoch][f'epoch_{epoch}']['loss']
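
Most of this file is whitespace cleanup plus updating the two metric call sites in train_one_epoch and validate_one_epoch to the new (outputs, targets, num_labels, multi_label, device) order. One pre-existing detail the Configs hunk touches is worth a note: dataclasses reject plain mutable defaults, which is why optimizer_params is declared with field(default_factory=...). A standalone illustration using only the standard library (DemoConfigs is a made-up name, not the repo's class):

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class DemoConfigs:
    # a bare `= {"lr": 2e-5}` would raise "mutable default ... is not allowed";
    # default_factory builds a fresh dict for every instance instead of sharing one
    optimizer_params: Dict[str, float] = field(default_factory=lambda: {"lr": 2e-5})

a, b = DemoConfigs(), DemoConfigs()
a.optimizer_params["lr"] = 1e-3
print(b.optimizer_params)   # {'lr': 2e-05}; b keeps its own dict
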
