-
-
Notifications
You must be signed in to change notification settings - Fork 386
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
248 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from functools import partial | ||
|
||
from catalyst import dl, SETTINGS | ||
|
||
# Registry of runnable engines, keyed by short CLI-friendly aliases.
# Optional backends (AMP, APEX, DeepSpeed, FairScale, XLA) are registered
# only when the corresponding requirement flag is enabled in SETTINGS.
E2E = {
    "de": dl.DeviceEngine,
    "dp": dl.DataParallelEngine,
    "ddp": dl.DistributedDataParallelEngine,
}

if SETTINGS.amp_required:
    E2E["amp-dp"] = dl.DataParallelAMPEngine
    E2E["amp-ddp"] = dl.DistributedDataParallelAMPEngine

if SETTINGS.apex_required:
    E2E["apex-dp"] = dl.DataParallelAPEXEngine
    E2E["apex-ddp"] = dl.DistributedDataParallelAPEXEngine

if SETTINGS.deepspeed_required:
    E2E["ds-ddp"] = dl.DistributedDataParallelDeepSpeedEngine

if SETTINGS.fairscale_required:
    E2E["fs-pp"] = dl.PipelineParallelFairScaleEngine
    E2E["fs-ddp"] = dl.SharedDataParallelFairScaleEngine
    E2E["fs-ddp-amp"] = dl.SharedDataParallelFairScaleAMPEngine
    # for some reason we could catch a bug with FairScale flatten wrapper here, so...
    E2E["fs-fddp"] = partial(
        dl.FullySharedDataParallelFairScaleEngine,
        ddp_kwargs={"flatten_parameters": False},
    )

if SETTINGS.xla_required:
    E2E["xla"] = dl.XLAEngine
    E2E["xla-ddp"] = dl.DistributedXLAEngine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
#!/usr/bin/env python | ||
# flake8: noqa | ||
from argparse import ArgumentParser, RawTextHelpFormatter | ||
|
||
from common import E2E | ||
|
||
from datasets import load_dataset | ||
from torch import nn, optim | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data.distributed import DistributedSampler | ||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_scheduler | ||
|
||
from catalyst import dl | ||
|
||
|
||
class CustomRunner(dl.IRunner):
    """Single-stage experiment runner: fine-tune ALBERT on GLUE SST-2.

    Args:
        logdir: directory for CSV/TensorBoard logs and checkpoints.
        engine: key into the ``E2E`` engine registry (e.g. ``"de"``, ``"ddp"``).
    """

    def __init__(self, logdir: str, engine: str):
        super().__init__()
        self._logdir = logdir
        self._engine = engine

    def get_engine(self):
        """Instantiate the hardware/distributed engine selected by name."""
        return E2E[self._engine]()

    def get_loggers(self):
        """Log to the console plus CSV and TensorBoard files under ``logdir``."""
        return {
            "console": dl.ConsoleLogger(),
            "csv": dl.CSVLogger(logdir=self._logdir),
            "tensorboard": dl.TensorboardLogger(logdir=self._logdir),
        }

    @property
    def stages(self):
        """The experiment has a single training stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Train for 10 epochs."""
        return 10

    def get_loaders(self, stage: str):
        """Tokenize SST-2 and build the train/valid dataloaders.

        Also records ``self.train_loader_len``, which ``get_scheduler`` reads
        to size the LR schedule (assumes catalyst calls ``get_loaders`` before
        ``get_scheduler`` during stage setup — TODO confirm).
        """
        datasets = load_dataset("glue", "sst2")
        tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        encoded_datasets = datasets.map(
            lambda examples: tokenizer(
                examples["sentence"],
                max_length=128,
                truncation=True,
                padding="max_length",
            ),
            batched=True,
        )
        # Mirror the "label" column under the "labels" key the model expects.
        encoded_datasets = encoded_datasets.map(lambda x: {"labels": x["label"]})
        encoded_datasets.set_format(
            type="torch", columns=["input_ids", "attention_mask", "labels"]
        )

        train_data = encoded_datasets["train"]
        valid_data = encoded_datasets["validation"]

        if self.engine.is_ddp:
            # Shard both splits across processes; shuffle only the train split.
            train_sampler = DistributedSampler(
                train_data,
                num_replicas=self.engine.world_size,
                rank=self.engine.rank,
                shuffle=True,
            )
            valid_sampler = DistributedSampler(
                valid_data,
                num_replicas=self.engine.world_size,
                rank=self.engine.rank,
                shuffle=False,
            )
        else:
            train_sampler = valid_sampler = None

        # Build the train loader once and reuse it; the original constructed an
        # identical throwaway DataLoader solely to measure its length.
        train_loader = DataLoader(train_data, batch_size=64, sampler=train_sampler)
        self.train_loader_len = len(train_loader)

        return {
            "train": train_loader,
            "valid": DataLoader(valid_data, batch_size=32, sampler=valid_sampler),
        }

    def get_model(self, stage: str):
        """Reuse the already-initialized model if present, else load ALBERT."""
        model = (
            self.model
            if self.model is not None
            else AutoModelForSequenceClassification.from_pretrained("albert-base-v2")
        )
        return model

    def get_criterion(self, stage: str):
        """Cross-entropy over the SST-2 classes."""
        return nn.CrossEntropyLoss()

    def get_optimizer(self, stage: str, model):
        """Adam with a standard fine-tuning learning rate."""
        return optim.Adam(model.parameters(), lr=3e-5)

    def get_scheduler(self, stage: str, optimizer):
        """Linear LR schedule over all training steps, stepped per batch.

        Warmup is 5% of the per-epoch step count multiplied by the epoch
        count (NOTE(review): possibly intended as 5% of *total* steps —
        kept as-is to preserve behavior).
        """
        scheduler = get_scheduler(
            "linear",
            optimizer=optimizer,
            num_warmup_steps=int(0.05 * self.train_loader_len) * self.stage_epoch_len,
            num_training_steps=self.train_loader_len * self.stage_epoch_len,
        )
        return scheduler

    def get_callbacks(self, stage: str):
        """Loss computation, optimizer/scheduler stepping, accuracy, checkpointing."""
        return {
            "criterion": dl.CriterionCallback(
                input_key="logits", target_key="labels", metric_key="loss"
            ),
            "optimizer": dl.OptimizerCallback(metric_key="loss"),
            "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss", mode="batch"),
            "accuracy": dl.AccuracyCallback(
                input_key="logits", target_key="labels", topk_args=(1,)
            ),
            # Keep only the single best checkpoint by validation accuracy.
            "checkpoint": dl.CheckpointCallback(
                self._logdir,
                loader_key="valid",
                metric_key="accuracy",
                minimize=False,
                save_n_best=1,
            ),
            # "tqdm": dl.TqdmCallback(),
        }

    def handle_batch(self, batch):
        """Forward pass; expose features/labels/logits for the callbacks."""
        outputs = self.model(**batch)

        self.batch = {
            "features": batch["input_ids"],
            "labels": batch["labels"],
            "logits": outputs.logits,
        }
|
||
|
||
if __name__ == "__main__":
    # parse_known_args lets distributed launchers pass extra flags without error
    cli = ArgumentParser(formatter_class=RawTextHelpFormatter)
    cli.add_argument("--logdir", type=str, default=None)
    cli.add_argument("--engine", type=str, choices=list(E2E.keys()))
    parsed, _ = cli.parse_known_args()
    if not parsed.logdir:
        # derive a default logdir from the engine name, e.g. "logs_albert_amp_ddp"
        parsed.logdir = f"logs_albert_{parsed.engine}".replace("-", "_")
    CustomRunner(parsed.logdir, parsed.engine).run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters