Merge pull request #105 from kaushikb11/tpu_save/fix
Fix save pretrained for TPUs
minimaxir committed May 17, 2021
Commit b22a20d (2 parents: 0c0a099 + e9ac598)
Showing 1 changed file with 38 additions and 16 deletions.

aitextgen/train.py
@@ -1,14 +1,17 @@
-import pytorch_lightning as pl
-from pytorch_lightning.callbacks.progress import ProgressBarBase
-from tqdm.auto import tqdm
+import os
+import shutil
+import subprocess
+import sys
 
 import torch
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup
-import os
-import shutil
-import subprocess
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks.progress import ProgressBarBase
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 
 
 class ATGTransformer(pl.LightningModule):
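
A side note on the new import: _TPU_AVAILABLE is a private PyTorch Lightning symbol (leading underscore), so its location is not guaranteed across Lightning releases. A defensive import, as a sketch under that assumption (not something this commit does):

    # Fall back gracefully if the installed pytorch_lightning version
    # does not export the private flag (assumption: the API may move).
    try:
        from pytorch_lightning.utilities import _TPU_AVAILABLE
    except ImportError:
        _TPU_AVAILABLE = False
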
@@ -18,12 +21,12 @@ class ATGTransformer(pl.LightningModule):
 
     def __init__(self, model, dataset, hparams, tokenizer):
         super(ATGTransformer, self).__init__()
-        self.model, self.dataset, self.hparams, self.tokenizer = (
+        self.model, self.dataset, self.tokenizer = (
             model,
             dataset,
-            hparams,
             tokenizer,
         )
+        self.save_hyperparameters(hparams)
 
     def forward(self, inputs):
         return self.model(**inputs, return_dict=False)
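
The __init__ change tracks a PyTorch Lightning API shift: newer Lightning versions expose self.hparams as a property populated by save_hyperparameters(), so assigning to it directly no longer works. A minimal sketch of the new pattern:

    import pytorch_lightning as pl

    class MinimalModule(pl.LightningModule):
        def __init__(self, hparams):
            super().__init__()
            # Registers the values under self.hparams and persists them in
            # checkpoints; direct assignment raises in newer Lightning.
            self.save_hyperparameters(hparams)
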
@@ -112,6 +115,10 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.train_transformers_only = train_transformers_only
         self.num_layers_freeze = num_layers_freeze
 
+    @property
+    def save_every_check(self):
+        return self.save_every > 0 and self.steps % self.save_every == 0
+
     def enabled(self):
         self.enabled = True
@@ -172,10 +179,19 @@ def on_batch_end(self, trainer, pl_module):
             desc += f" — GPU Mem: {gpu_memory} MB"
         self.main_progress_bar.update(self.progress_bar_refresh_rate)
         self.main_progress_bar.set_description(desc)
 
+        if _TPU_AVAILABLE and self.save_every_check:
+            did_unfreeze = False
+            if self.enabled:
+                self.unfreeze_layers(pl_module)
+                did_unfreeze = True
+            self.save_pytorch_model(trainer, pl_module, tpu=True)
+            if did_unfreeze:
+                self.freeze_layers(pl_module)
+
         if self.enabled:
             did_unfreeze = False
-            if self.save_every > 0 and self.steps % self.save_every == 0:
+            if not _TPU_AVAILABLE and self.save_every_check:
                 self.unfreeze_layers(pl_module)
                 self.save_pytorch_model(trainer, pl_module)
                 did_unfreeze = True
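
Note that the TPU branch runs before, and independently of, the self.enabled guard: on a multi-core TPU the progress bar callback is enabled on only one process, but xm.save synchronizes across cores, so (presumably) every process must reach the save call rather than only the one driving the progress bar. A runnable distillation of just that routing decision, with hypothetical names:

    # Hypothetical condensation of the branching above: TPU saves bypass
    # the enabled guard, while CPU/GPU saves still require it.
    def takes_tpu_save_path(tpu_available, enabled, steps, save_every):
        save_now = save_every > 0 and steps % save_every == 0
        return tpu_available and save_now

    assert takes_tpu_save_path(True, False, 500, 500)      # disabled TPU process still saves
    assert not takes_tpu_save_path(False, True, 500, 500)  # CPU/GPU saves use the enabled branch
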
@@ -219,13 +235,19 @@ def generate_sample_text(self, trainer, pl_module):
 
         self.main_progress_bar.write("=" * 10)
 
-    def save_pytorch_model(self, trainer, pl_module):
-        self.main_progress_bar.write(
-            f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
-        )
-        pl_module.model.save_pretrained(self.output_dir)
+    def save_pytorch_model(self, trainer, pl_module, tpu=False):
+
+        if self.enabled:
+            self.main_progress_bar.write(
+                f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
+            )
+        if tpu:
+            import torch_xla.core.xla_model as xm
+            pl_module.model.save_pretrained(self.output_dir, save_function=xm.save)
+        else:
+            pl_module.model.save_pretrained(self.output_dir)
 
-        if self.save_gdrive:
+        if self.enabled and self.save_gdrive:
             for pt_file in ["pytorch_model.bin", "config.json"]:
                 shutil.copyfile(
                     os.path.join(self.output_dir, pt_file),
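
Since save_pretrained with save_function=xm.save still writes a standard Hugging Face model directory, a quick way to verify a TPU-side save is to reload it on CPU. A sketch, assuming aitextgen's default trained_model output directory and a causal LM:

    from transformers import AutoModelForCausalLM

    # Reads the config.json and pytorch_model.bin written above; the
    # directory name is an assumption based on aitextgen's defaults.
    model = AutoModelForCausalLM.from_pretrained("trained_model")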
