In [None]:
def get_losses(model_names):
    """Collect training and evaluation loss curves for each model.

    For every model name, the newest ``checkpoint-<step>`` directory under
    ``PATH_TO_CHECKPOINTS/<model name>`` is located and its
    ``trainer_state.json`` is read. Only the most recent checkpoint is used
    because its ``log_history`` already contains the logs from every step.

    Args:
        model_names: Iterable of model directory names under
            ``PATH_TO_CHECKPOINTS``.

    Returns:
        A tuple ``(losses, timesteps, eval_losses, eval_timesteps)`` of lists,
        each holding one inner list per model, in ``model_names`` order.
        ``losses``/``timesteps`` come from log entries that have a ``"loss"``
        key; ``eval_losses``/``eval_timesteps`` from entries that have an
        ``"eval_loss"`` key. Entries with neither key are skipped.

    Raises:
        FileNotFoundError: If a model directory contains no
            ``checkpoint-<integer>`` entry, or the newest checkpoint has no
            ``trainer_state.json``.
    """
    losses, timesteps, eval_losses, eval_timesteps = [], [], [], []
    for name in model_names:
        checkpoint_dir = _newest_checkpoint_dir(os.path.join(PATH_TO_CHECKPOINTS, name))
        state_path = os.path.join(checkpoint_dir, "trainer_state.json")
        with open(state_path, "r", encoding="utf-8") as f:
            state = json.load(f)

        losses_model, timesteps_model = [], []
        eval_losses_model, eval_timesteps_model = [], []
        # Each entry of log_history is a dict.
        for log in state["log_history"]:
            if "eval_loss" in log:
                eval_losses_model.append(log["eval_loss"])
                eval_timesteps_model.append(log["step"])
            elif "loss" in log:
                losses_model.append(log["loss"])
                timesteps_model.append(log["step"])
            # Entries with neither key are ignored instead of raising KeyError.

        losses.append(losses_model)
        timesteps.append(timesteps_model)
        eval_losses.append(eval_losses_model)
        eval_timesteps.append(eval_timesteps_model)

    return losses, timesteps, eval_losses, eval_timesteps


def _newest_checkpoint_dir(model_path):
    """Return the path of the ``checkpoint-<N>`` entry in *model_path* with the largest N."""
    prefix = "checkpoint-"
    steps = [
        int(entry[len(prefix):])
        for entry in os.listdir(model_path)
        # Require "checkpoint-<digits>" so names like "checkpoints" or
        # "checkpoint-7.zip" are ignored rather than crashing int().
        if entry.startswith(prefix) and entry[len(prefix):].isdecimal()
    ]
    if not steps:
        raise FileNotFoundError(
            f"No 'checkpoint-<step>' directory found in {model_path!r}"
        )
    return os.path.join(model_path, prefix + str(max(steps)))
                    


def plot_losses(losses, timesteps, model_names, metric_name):
    """Draw one curve per model for ``metric_name`` and show the figure.

    Each model's values are plotted twice in the same colour: once as point
    markers and once as a connecting line, which carries the legend label.
    Colours are sampled evenly across the viridis colormap, one per entry of
    ``losses``.

    Args:
        losses: Sequence of per-model y-value sequences.
        timesteps: Sequence of per-model x-value sequences, aligned with
            ``losses``.
        model_names: Legend label for each model, in the same order.
        metric_name: Used as both the y-axis label and the plot title.
    """
    palette = plt.cm.viridis(np.linspace(0, 1, len(losses)))
    curves = zip(losses, timesteps, palette, model_names)
    for values, steps, colour, label in curves:
        plt.plot(steps, values, '.', color=colour)                   # markers
        plt.plot(steps, values, alpha=1, color=colour, label=label)  # line + legend entry

    plt.legend(loc="upper right")
    plt.xlabel("Batches")
    plt.ylabel(metric_name)
    plt.title(metric_name)
    plt.grid()
    plt.show()
    
def save_list(strings, out_path, encoding="utf-8"):
    """Write each string to ``out_path``, each followed by three newlines.

    The file is overwritten. Writing uses an explicit encoding (UTF-8 by
    default) so that non-ASCII text does not depend on the platform's locale
    encoding.

    Args:
        strings: A single string, or an iterable of strings.
        out_path: Destination file path.
        encoding: Text encoding for the output file.
    """
    # A bare string would otherwise be iterated character by character.
    strings = [strings] if isinstance(strings, str) else strings
    with open(out_path, "wt", encoding=encoding) as f:
        f.writelines(s + "\n\n\n" for s in strings)
          
        
def save_model_outs(articles, model, tokenizer, out_path, device=None):
    """Summarize ``articles`` with ``model`` and write the summaries to ``out_path``.

    A ``"summarization"`` pipeline is built from ``model`` and ``tokenizer``.
    The ``"summary_text"`` of each output is collected and written via
    ``save_list``.

    Args:
        articles: Input passed directly to the summarization pipeline.
        model: Model passed to ``pipeline``.
        tokenizer: Tokenizer passed to ``pipeline``.
        out_path: File the summaries are written to.
        device: Pipeline device (``-1`` for CPU, ``0`` for the first CUDA
            device). When ``None``, ``0`` is used if CUDA is available and
            ``-1`` otherwise. Pass ``-1`` to force CPU.
    """
    if device is None:
        # Use the first GPU when one is available, otherwise run on CPU.
        device = 0 if torch.cuda.is_available() else -1
    pipe = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)
    summaries = [out["summary_text"] for out in pipe(articles)]
    save_list(summaries, out_path)

# TODO: load the "after" model from a training checkpoint.
# Given a list of articles, save model summaries from before and after training.
def before_and_after(articles, model_name):
    """Save summaries of ``articles`` from the untrained and trained model.

    The tokenizer is loaded from ``TOKENIZER_PATHS[model_name]``, the untrained
    model from ``UNTRAINED_PATHS[model_name]``, and the trained model from
    ``TRAINED_PATHS[model_name]``. All are loaded with
    ``local_files_only=True``. Each model is passed to ``freeze_weights`` and
    then to ``save_model_outs``.

    Outputs are written to ``model_name + "_before.txt"`` and
    ``model_name + "_after.txt"``.

    Args:
        articles: Articles to summarize, forwarded to ``save_model_outs``.
        model_name: Key into the path dictionaries; also the output-file
            prefix.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATHS[model_name], local_files_only=True)

    before_model = AutoModelForSeq2SeqLM.from_pretrained(UNTRAINED_PATHS[model_name], local_files_only=True)
    freeze_weights(before_model)
    # Same "<name>_<stage>.txt" naming as the "after" file below.
    save_model_outs(articles, before_model, tokenizer, model_name + "_before.txt")
    # Drop this function's reference to the untrained model before loading the trained one.
    del before_model

    after_model = AutoModelForSeq2SeqLM.from_pretrained(TRAINED_PATHS[model_name], local_files_only=True)
    freeze_weights(after_model)
    save_model_outs(articles, after_model, tokenizer, model_name + "_after.txt")
    