## Imports

In [None]:
from __future__ import annotations

import os
import random
import subprocess

from dotenv import load_dotenv
import torch
import tweepy
from transformers import T5Tokenizer, AutoModelForCausalLM

In [None]:
# Load variables from a .env file into the process environment (python-dotenv)
# so the os.getenv() lookups for the Twitter API credentials can resolve.
load_dotenv()

In [None]:
# Fine-tune rinna/japanese-gpt2-small on train_texts/timeline.txt with the
# Hugging Face `run_clm.py` example script, writing checkpoints to
# finetuned_model/.
#
# The command is passed as an argument list (no shell involved), and
# check=True makes a failed training run raise CalledProcessError instead of
# silently letting later cells load a stale or missing model.
subprocess.run(
    [
        "python",
        "./transformers/examples/pytorch/language-modeling/run_clm.py",
        "--model_name_or_path=rinna/japanese-gpt2-small",
        "--train_file=train_texts/timeline.txt",
        "--do_train",
        "--num_train_epochs=10",
        "--save_steps=10000",
        "--save_total_limit=3",
        "--per_device_train_batch_size=1",
        "--output_dir=finetuned_model/",
        "--overwrite_output_dir",
        "--use_fast_tokenizer=False",
    ],
    check=True,
)

In [None]:
# Tokenizer of the base checkpoint; lower-casing is switched on explicitly
# after loading.
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
tokenizer.do_lower_case = True

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the fine-tuned weights, move them to the chosen device and switch the
# model to evaluation mode.
model = AutoModelForCausalLM.from_pretrained("../finetuned_model/").to(device)
model.eval()

In [None]:
def post_tweet(
    model: AutoModelForCausalLM,
    tokenizer: T5Tokenizer,
) -> tweepy.Response:
    """Generate a tweet from a random timeline prompt and post it.

    A random non-empty line of ``../train_texts/timeline_today.txt`` is
    stripped of its ``<s>``/``</s>`` markers, cut to at most 10 characters and
    used as the prompt for sampling-based generation. The generated text is
    printed and then posted through the Twitter API v2 using credentials read
    from environment variables.

    Args:
        model: Causal language model used to generate the tweet text.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The response returned by ``tweepy.Client.create_tweet``.

    Raises:
        ValueError: If the prompt file contains no usable (non-empty) line.
    """
    client = tweepy.Client(
        consumer_key=os.getenv("API_KEY"),
        consumer_secret=os.getenv("API_SECRET_KEY"),
        bearer_token=os.getenv("BEARER_TOKEN"),
        access_token=os.getenv("ACCESS_TOKEN"),
        access_token_secret=os.getenv("ACCESS_TOKEN_SECRET"),
    )

    # Explicit UTF-8: the file holds Japanese text and the platform default
    # encoding is not guaranteed to decode it.
    with open("../train_texts/timeline_today.txt", "r", encoding="utf-8") as f:
        cleaned = (
            line.replace("<s>", "").replace("</s>", "").replace("\n", "")
            for line in f
        )
        candidates = [text for text in cleaned if text]

    # Fail loudly instead of looping forever (or hitting IndexError inside
    # random.choice) when there is nothing to draw a prompt from.
    if not candidates:
        raise ValueError(
            "No non-empty prompt lines found in ../train_texts/timeline_today.txt"
        )

    # Only the first 10 characters seed the generation.
    prompt = random.choice(candidates)[:10]

    # Put the inputs on the same device as the model's weights.
    input_ids = tokenizer.encode(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=60,
            min_length=10,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # The keyword is `bad_words_ids`; the previous spelling
            # `bad_word_ids` is not a generate() argument.
            bad_words_ids=[[tokenizer.unk_token_id]],
            # NOTE(review): a value below 1.0 rewards repetition rather than
            # penalising it — confirm that 0.99 is intentional.
            repetition_penalty=0.99,
            num_return_sequences=1,
        )
    text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]

    print(text)
    # Public equivalent of the former private `_make_request("POST", "/2/tweets", ...)`.
    return client.create_tweet(text=text, user_auth=True)

In [None]:
# Generate one tweet with the loaded model/tokenizer and post it.
post_tweet(model=model, tokenizer=tokenizer)