<a href="https://colab.research.google.com/github/buganart/dialog/blob/master/dialog.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Setup
# @markdown 1. Before starting please save the notebook in your drive by clicking on `File -> Save a copy in drive`
# @markdown 2. Check GPU, should be a Tesla V100 if you want to train it as fast as possible.
# @markdown 3. Mount google drive.
# @markdown 4. Log in to wandb.


!nvidia-smi -L
import os

print(f"We have {os.cpu_count()} CPU cores.")
print()

try:
    from google.colab import drive, output

    IN_COLAB = True
except ImportError:
    from IPython.display import clear_output

    IN_COLAB = False

from pathlib import Path

if IN_COLAB:
    drive.mount("/content/drive/")

    if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
        raise RuntimeError(
            "Shortcut to our shared drive folder doesn't exits.\n\n"
            "\t1. Go to the google drive web UI\n"
            '\t2. Right click shared folder IRCMS_GAN_collaborative_database and click "Add shortcut to Drive"'
        )

clear = output.clear if IN_COLAB else clear_output


def clear_on_success(msg="Ok!"):
    """Clear the cell output and print `msg` if `_exit_code` is zero.

    `_exit_code` and `clear` are notebook-level globals. When `_exit_code`
    is non-zero nothing is cleared or printed, so the existing output stays
    visible.
    """
    if _exit_code != 0:
        return
    clear()
    print(msg)


print()
print("Wandb installation and login ...")
%pip install -q wandb

wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
wandb_local_netrc_path = Path("/root/.netrc")
if wandb_drive_netrc_path.exists():
    import shutil

    print("Wandb .netrc file found, will use that to log in.")
    shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
else:
    print(
        f"Wandb config not found at {wandb_drive_netrc_path}.\n"
        f"Using manual login.\n\n"
        f"To use auto login in the future, finish the manual login first and then run:\n\n"
        f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
        f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
        f"Then that file will be used to login next time.\n"
    )

!wandb login

In [None]:
#@title Configuration

# Fill in the configuration, then select `Runtime` -> `Run all` and let it ride!

#@markdown #### Training
# NOTE: from here on `drive` is a Path; it rebinds the `google.colab.drive`
# module name imported by the setup cell (which re-imports it when re-run).
drive = Path('/content/drive/MyDrive')
print(f"Google drive at {drive}")

drive_audio_db_root = drive
collaborative_database = drive / "IRCMS_GAN_collaborative_database"

# Without this check, `mkdir(parents=True)` below would silently create the
# shared folder (and the mount point itself) when the drive is not mounted or
# the shortcut is missing.
if IN_COLAB and not collaborative_database.exists():
    raise RuntimeError(
        f"Shared folder {collaborative_database} not found. "
        "Is the drive mounted and the shortcut added?"
    )

experiment_dir = collaborative_database / "Experiments" / "colab-dialog"
experiment_dir.mkdir(parents=True, exist_ok=True)

#@markdown The path of the text directory you would like to work with
text_dir = "/content/drive/MyDrive/IRCMS_GAN_collaborative_database/Research/Daniel/GPT2TEXT" #@param {type:"string"}
text_dir = Path(text_dir)
if not text_dir.exists():
    raise RuntimeError(f"The text_dir {text_dir} does not exist.")

#@markdown Name of pre-trained model (only tested `-small` so far).
pretrained_model = "microsoft/DialoGPT-medium" #@param ["\"microsoft/DialoGPT-small\"", "\"microsoft/DialoGPT-medium\"", "\"microsoft/DialoGPT-large\""] {type:"raw", allow-input: true}

#@markdown training parameters
batch_size = 1 #@param {type: "integer"}
num_context = 7 #@param {type: "integer"}

# #@markdown [Optional] ID of wandb run to resume.
# #resume_run_id = "" #@param {type: "string"}

def check_wandb_id(run_id):
    """Validate the format of a wandb run ID.

    An empty or falsy `run_id` (no run to resume) is accepted. Otherwise the
    ID must be exactly 8 characters, each a lowercase ASCII letter or digit.

    Raises:
        RuntimeError: if `run_id` is non-empty and does not match that format.
    """
    import re

    # `fullmatch` rejects a trailing newline, which `match` + `$` would accept;
    # the explicit `[0-9a-z]` class rejects the non-ASCII digits `\d` allows.
    if run_id and not re.fullmatch(r"[0-9a-z]{8}", run_id):
        raise RuntimeError(
            "Run ID needs to be 8 characters long and contain only letters a-z and digits.\n"
            f"Got {run_id!r}"
        )

# check_wandb_id(resume_run_id)

# Summary of every setting chosen in this cell, including the training
# parameters that are passed to the finetuning step.
config = dict(
    text_dir=text_dir,
    experiment_dir=experiment_dir,
    pretrained_model=pretrained_model,
    batch_size=batch_size,
    num_context=num_context,
    # resume_run_id=resume_run_id,
)
# Print one aligned `name: value` line per setting.
for k,v in config.items():
    print(f"=> {k:20}: {v}")

In [None]:
#@title Clone `buganart/dialog` repo.
if IN_COLAB:
    !git clone https://github.com/buganart/dialog
    clear_on_success("Repo cloned!")

In [None]:
#@title Install dependencies
# Install the pinned torch version first, then the `./dialog` directory as an
# editable package.
%pip install torch==1.4.0
%pip install -e ./dialog
# Re-run site initialisation; presumably so the editable install becomes
# importable without restarting the runtime — TODO confirm.
import site
site.main()
# Clears the install output and prints the message only when `_exit_code` is 0.
clear_on_success("Dependencies installed!")

In [5]:
#@title Finetune
# Environment variables set for wandb before training starts.
env_vars = dict(
    WANDB_ENTITY="bugan",
    WANDB_PROJECT="dialog",
)
os.environ.update(env_vars)

from dialog import finetune

# Finetune with the settings chosen in the configuration cell. The returned
# trainer is reused by the generation cell.
trainer = finetune.train(
    text_dir=text_dir, 
    save_dir=experiment_dir, 
    pretrained_model=pretrained_model,
    num_context=num_context,
    batch_size=batch_size,
)

import wandb

# `wandb.join()` is the deprecated spelling of `wandb.finish()`; prefer
# `finish` when the installed wandb provides it and fall back otherwise.
finish_run = getattr(wandb, "finish", None) or wandb.join
finish_run()

In [None]:
#@title Generate text.
import torch
from dialog import generate

#@markdown Prefix for text generation.
prefix = "Hurry up Morty!" #@param {type:"string"}
#@markdown Number of responses to generate.
steps =  20#@param {type:"integer"}

#@markdown [Optional] Checkpoint path with trained model. **No need to specify if you just trained the model**.
checkpoint_dir = "" #@param {type:"string"}

# Use the first GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if checkpoint_dir:
    print(f"Loading models from checkpoint {checkpoint_dir}")
    model = generate.load_model(checkpoint_dir).to(device)
    tokenizer = generate.load_tokenizer(checkpoint_dir)
else:
    # No checkpoint given: reuse the model trained by the Finetune cell. Fail
    # with an actionable message instead of a bare NameError when that cell
    # has not been run in this session.
    if "trainer" not in globals():
        raise RuntimeError(
            "No trained model available. Run the Finetune cell first or set checkpoint_dir."
        )
    model = trainer.model.to(device)
    tokenizer = trainer.tokenizer

generate.generate(
    model=model,
    tokenizer=tokenizer,
    device=device,
    prefix=prefix,
    steps=steps,
)