# Goal of this notebook

Develop a training loop for finetuning ASR models using TTS loss by recreating RL training found in RL4LMs/rl4lms/envs/text_generation/training_utils.py

# automatic reloading magic

# imports

In [2]:
# print hostname to make sure we are on correct node
import socket
print(socket.gethostname())

strickland.inf.ed.ac.uk


In [3]:
import os
os.getcwd()

'/disk/nfs/ostrom/s1785140/rlspeller'

In [4]:
import torch
from typing import List, Dict, Tuple, Any
import hyperpyyaml
from tqdm import tqdm
from torchaudio.models.decoder import ctc_decoder
from torch.nn.functional import softmax
import random
from jiwer import cer
import numpy as np

In [5]:
torch.cuda.is_available()

True

In [6]:
import speechbrain as sb

In [7]:
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# HPARAMS

In [8]:
hparams = {
    "softdtw_temp": 0.01,
    "softdtw_bandwidth": 120,
    "dist_func": "l1",
    "sentencepiece_model_path": "/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/0_char.model",
    'speechbrain_hparams_file': '/home/s1785140/rlspeller/infer_speechbrain.yaml',
}

# TOKENIZER

In [9]:
# load pretrained tokenizer used to tokenizer ASR training inputs 
import sentencepiece as spm 
spm_path = hparams["sentencepiece_model_path"]
sp = spm.SentencePieceProcessor()
sp.load(spm_path)
print(sp.vocab_size())

28


In [10]:
# test tokenizer
s = "hello world my name is jason"
# TODO pass string through text cleaners? 
encoded = sp.EncodeAsIds(s)
assert 0 not in encoded, "tried to encode an unknown character"
print(" ".join(str(idx) for idx in encoded))

1 10 2 12 12 4 1 17 4 9 12 11 1 16 20 1 6 5 16 2 1 7 8 1 26 5 8 4 6


In [11]:
sp.DecodeIds(encoded)

'hello world my name is jason'

# NEW! SIMPLE TOKENIZER

In [12]:
from speechbrain.tokenizers.SimpleTokenizer import SimpleTokenizer

In [13]:
tokenizer = SimpleTokenizer()

In [14]:
text = "hello my name is jason"
text = text.replace(' ', '|')
print(text)
ids = tokenizer.encode_as_ids(text)
ids

hello|my|name|is|jason


[9, 6, 13, 13, 16, 1, 14, 26, 1, 15, 2, 14, 6, 1, 10, 20, 1, 11, 2, 20, 16, 15]

In [15]:
tokenizer.decode_ids(ids)

'hello|my|name|is|jason'

## test simple tokenizer with probability distribution, and see if CTC decoder successfully generates n-best lists

In [16]:
# create empty array of correct dimensions
min_len, max_len = 50, 100
bsz = 4
lens = torch.randint(min_len, max_len, (bsz,))
vocab_size = len(tokenizer.vocab)

# randomly assign probaility distribution to each timestep

# try to decode

In [17]:
randn = torch.randn(bsz, max_len, vocab_size)

In [18]:
ctc_probs = softmax(randn, dim=1)
# ctc_probs

In [19]:
ctc_beamsearch_decoder_test = ctc_decoder(
    lexicon=None,
    # tokens="/home/s1785140/rlspeller/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/tokens.txt",
    tokens=tokenizer.vocab,
    nbest=2,
    blank_token='-',
    sil_token="|",
)

predicted_ids = ctc_beamsearch_decoder_test(ctc_probs, lens)

predicted_words = []
for i, hyps in enumerate(predicted_ids):
    for j, hyp in enumerate(hyps):
        words = tokenizer.decode_ids(hyp.tokens.tolist()).split(" ")
        tup = (f"sample {i+1}, hyp {j+1}/{len(hyps)}", words)
        predicted_words.append(tup)
        print(tup)

('sample 1, hyp 1/2', ['|ydjno|dumnf|ramucrtgolkjfsyipquyvuctvghwnbctnijveywrzgzy|'])
('sample 1, hyp 2/2', ['|ydjno|dumnf|ramucrtgolk|jfsyipquyvuctvghwnbctnijveywrzgzy|'])
('sample 2, hyp 1/2', ['|yiewfjhwqouhdvmewtrlekncpkceorxdsonvqhnspgcgac|gaxm|'])
('sample 2, hyp 2/2', ['|yiewfjhwqouhdvmewtrlekncpkceorxdsonvqhnspcgac|gaxm|'])
('sample 3, hyp 1/2', ['|plmqd|ublcbwdeodn|bhpfptegqwleomvsudveyarodysrt|ta|i|'])
('sample 3, hyp 2/2', ['|plmqd|ublcbweodn|bhpfptegqwleomvsudveyarodysrt|ta|i|'])
('sample 4, hyp 1/2', ['|lhpbsmrzrcpnjmnwpvrxepqwuamizxafkmvmluinhlifbsnbkqyrqnuyq|lfuqhkcvczmnvqgoiyzbrtzids|'])
('sample 4, hyp 2/2', ['|lhpbsmrzrcpnjmnwpvrxepqwuamizxafkmvmluinhlifbsnbkqyrqnuyq|lfuqhkcvczmnvqpoiyzbrtzids|'])


# LOAD ASR (PRETRAINED)

In [20]:
from templates.speech_recognition_CharTokens_NoLM.ASR.train import ASR
from templates.speech_recognition_CharTokens_NoLM.ASR.train import dataio_prepare
from torch.utils.data import DataLoader

In [21]:

from speechbrain.dataio.dataloader import LoopedLoader

In [22]:
# Load hyperparameters file with command-line overrides
speechbrain_hparams_file = hparams['speechbrain_hparams_file']
with open(speechbrain_hparams_file) as f:
    speechbrain_hparams = hyperpyyaml.load_hyperpyyaml(f)

/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/rirs_noises.zip exists. Skipping download


In [23]:
speechbrain_hparams['save_folder']

'/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/ASR/results/CRDNN_CHAR_LJSpeech_halved/2602/save'

In [24]:
# initialise trainer (we don't want to train, but model is tightly coupled with trainer)
asr_brain = ASR(
    modules=speechbrain_hparams["modules"],
    opt_class=speechbrain_hparams["opt_class"],
    hparams=speechbrain_hparams,
    checkpointer=speechbrain_hparams["checkpointer"],
)

def setup_asr_brain_for_infer(asr_brain):
    asr_brain.on_evaluate_start(min_key="WER") # We call the on_evaluate_start that will load the best model
    asr_brain.modules.eval() # We set the model to eval mode (remove dropout etc)

print("if on_evaluate_start() get runtime error, likely need to restart notebook kernel")
setup_asr_brain_for_infer(asr_brain)

if on_evaluate_start() get runtime error, likely need to restart notebook kernel


In [25]:
# create dataset and dataloader for inference
datasets = dataio_prepare(speechbrain_hparams)

test_set = datasets['test']

if not isinstance(test_set, DataLoader) or isinstance(test_set, LoopedLoader):
    test_loader_kwargs=speechbrain_hparams["test_dataloader_opts"]
    test_set = asr_brain.make_dataloader(
        test_set, stage=sb.Stage.TEST, **test_loader_kwargs
    )

In [26]:
# get vocab from tokenizer (needed for ctc decoding)
vocab_size = len(asr_brain.hparams.tokenizer)
vocab = []
for i in range(vocab_size):
    vocab.append(asr_brain.hparams.tokenizer.decode_ids([i]))
print(vocab)

# edit vocab to match default ctc decoder symbols for blank and silence
vocab[0] = '-'
vocab[1] = "|"

print(vocab)

[' ⁇ ', '', 'e', 't', 'o', 'a', 'n', 'i', 's', 'r', 'h', 'd', 'l', 'c', 'f', 'u', 'm', 'w', 'p', 'g', 'y', 'b', 'v', 'k', 'x', 'q', 'j', 'z']
['-', '|', 'e', 't', 'o', 'a', 'n', 'i', 's', 'r', 'h', 'd', 'l', 'c', 'f', 'u', 'm', 'w', 'p', 'g', 'y', 'b', 'v', 'k', 'x', 'q', 'j', 'z']


In [27]:
ctc_beamsearch_decoder = ctc_decoder(
    lexicon=None,
    # tokens="/home/s1785140/rlspeller/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/tokens.txt",
    tokens=vocab,
    nbest=100,
    blank_token='-',
    sil_token="|",
)

In [28]:
# generate transcriptions for all batches in test set
def transcribe_dataset(asr_brain, dataset, greedy=False, num_batches_to_transcribe=None):
    # Now we iterate over the dataset and we simply compute_forward and decode
    with torch.no_grad():
        transcripts = []
        for batch in tqdm(list(dataset)[:num_batches_to_transcribe], dynamic_ncols=True):
            orig_transcriptions = batch.words

            # Make sure that your compute_forward returns the predictions !!!
            # In the case of the template, when stage = TEST, a beam search is applied 
            # in compute_forward(). 
            predictions = asr_brain.compute_forward(batch, stage=sb.Stage.TEST)
            
            ctc_probs = predictions['ctc_logprobs'] # FOR DEBUG

            if greedy:
                predicted_ids = sb.decoders.ctc_greedy_decode(
                    predictions["ctc_logprobs"], asr_brain.feat_lens, blank_id=asr_brain.hparams.blank_index
                )
                predicted_words = [
                    asr_brain.tokenizer.decode_ids(ids).split(" ")
                    for ids in predicted_ids
                ]
            else:
                # get mel lens from wav len ratios since torch ctc decoder requires lens in frames
                batch_max_len = predictions["ctc_logprobs"].size(1)
                bsz = predictions["ctc_logprobs"].size(0)
                mel_lens = torch.zeros(bsz)
                for i, len_ratio in enumerate(asr_brain.feat_lens):
                    mel_lens[i] = int(torch.round(len_ratio * batch_max_len))
                
                predicted_ids = ctc_beamsearch_decoder(
                    predictions["ctc_logprobs"], lengths=mel_lens
                )

                predicted_words = []
                for i, (utt_id, orig_text, hyps) in enumerate(zip(batch.utt_id, orig_transcriptions, predicted_ids)):
                    print(f"\nsample {i+1} - ({utt_id}: '{orig_text}')")
                    sample_cers = []
                    for j, hyp in enumerate(hyps):
                        words = asr_brain.hparams.tokenizer.decode_ids(hyp.tokens.tolist()) # .split("|")
                        # words = tokenizer.decode_ids(hyp.tokens.tolist()) # .split("|")
                        hyp_cer = 100 * cer(orig_text, words)
                        sample_cers.append(hyp_cer)
                        print(f"\thyp {j+1}/{len(hyps)} (CER={hyp_cer:.1f}%): '{words}'")
                        predicted_words.append((f"sample {i+1}, hyp {j+1}/{len(hyps)}", words))
                        
                    print(f"\t=== Mean CER: {np.mean(sample_cers):.1f}%, Std CER: {np.std(sample_cers):.1f}% ===")

            transcripts.append(predicted_words)

    return transcripts, ctc_probs

transcripts, ctc_probs = transcribe_dataset(asr_brain, test_set, greedy=False, num_batches_to_transcribe=1)

  0%|                                                                                        | 0/1 [00:00<?, ?it/s]

DEBUG INSIDE PREPARE FEATURES, feats.shape=torch.Size([8, 627, 40]) wav_lens.shape=torch.Size([8])


100%|████████████████████████████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.84s/it]


sample 1 - (LJ039-0175: 'for the first four attempts the firers missed the second shot by several inches')
	hyp 1/48 (CER=0.0%): 'for the first four attempts the firers missed the second shot by several inches '
	hyp 2/48 (CER=1.3%): 'for the first four attempts the firers mised the second shot by several inches '
	hyp 3/48 (CER=1.3%): 'for the fist four attempts the firers missed the second shot by several inches '
	hyp 4/48 (CER=1.3%): 'for the first four attempths the firers missed the second shot by several inches '
	hyp 5/48 (CER=2.5%): 'for the first four attempts the firerers missed the second shot by several inches '
	hyp 6/48 (CER=2.5%): 'for the fist four attempts the firers mised the second shot by several inches '
	hyp 7/48 (CER=1.3%): 'fr the first four attempts the firers missed the second shot by several inches '
	hyp 8/48 (CER=1.3%): 'for the first four attempts the firers  missed the second shot by several inches '
	hyp 9/48 (CER=1.3%): 'for the first four atempts the




In [29]:
raise ValueError("stop here for development of ctc beam search decoder")

ValueError: stop here for development of ctc beam search decoder

# DATAPOOL

In [None]:
## load SpellerDataset

In [None]:
from rlspeller.dataset import SpellerDataset

def load_dataset(split):
    wordaligned_speechreps_dir = '/home/s1785140/data/ljspeech_speechbrain/wordaligned_mels'
    wordlists = {
        "train": '/home/s1785140/data/ljspeech_fastpitch/respeller_train_words.json',
        "val": '/home/s1785140/data/ljspeech_fastpitch/respeller_val_words.json',
        "test": '/home/s1785140/data/ljspeech_fastpitch/respeller_test_words.json',
    }
    
    return SpellerDataset(wordaligned_speechreps_dir, wordlists[split])

In [None]:
# define Sample and Datapool

In [None]:
from dataclasses import dataclass
from abc import abstractclassmethod

@dataclass(init=True)
class Sample:
    id: str # 
    gt_mel_path: str # full path to mel spectrogram
    gt_text: str # original spelling of word
    meta_data: Dict[str, Any] = None


class ASRPool:
    def __init__(self, samples: List[Sample]):
        self._samples = samples

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, ix: int) -> Sample:
        if ix >= len(self):
            raise StopIteration
        sample = self._samples[ix]
        return sample, 1.0

    def sample(self) -> Sample:
        random_sample = random.choice(self._samples)
        return random_sample

    @abstractclassmethod
    def prepare(cls, **args) -> 'TextGenPool':
        """
        A factory method to instantiate data pool
        """
        raise NotImplementedError

    def split(self, split_ratios: List[float]) -> List['TextGenPool']:
        start_ix = 0
        pools = []
        for ratio in split_ratios:
            count = int(len(self) * ratio)
            end_ix = start_ix + count
            pools.append(type(self)(self._samples[start_ix: end_ix]))
            start_ix = end_ix
        return pools
    
class LJSpeech(ASRPool):
    @classmethod
    def prepare(cls, split: str, **args) -> 'ASRPool':
        ds = load_dataset(split)
        samples = []
        for idx, item in tqdm(enumerate(ds)):
            # sample = Sample(
            #     id: f"{split}_{idx}" 
            #     gt_mel_path: str 
            #     gt_text: str 
            # )
            samples.append(sample)

        pool_instance = cls(samples)
        return pool_instance

In [None]:
train_datapool = LJSpeech.prepare('train')

In [None]:
train_datapool._samples

In [None]:
train_datapool = _get_datapool_by_split("train")
val_datapool = _get_datapool_by_split("val")
test_datapool = _get_datapool_by_split("test")

samples_by_split = {
    "train": [(sample, weight)
              for sample, weight in train_datapool],
    "val": [sample for sample, _ in val_datapool],
    "test": [sample for sample, _ in test_datapool]
}
return samples_by_split

# REWARD FUNCTION

## funcs to load pretrained fastpitch model

In [None]:
import argparse
from fastpitch import models as fastpitch_model

parser = argparse.ArgumentParser(description='Fastpitch Model Config Parser', allow_abbrev=False)
parser = fastpitch_model.parse_model_args('FastPitch', parser)
args, unk_args = parser.parse_known_args()

In [None]:
print("WARNING!!! unknown args:", unk_args)

### Training command for no punctuation fastpitch:

```bash
cd 
source activate_respeller.sh

cd ~/respeller/fastpitch

EXP_NAME=halved_ljspeech_data_nospaces_noeos_pad_lowercase_nopunc

DATA_ROOT=~/data/ljspeech_fastpitch
CHECKPOINT_DIR=exps
mkdir $CHECKPOINT_DIR
HIFIGAN_CHKPT=~/pretrained_models/hifigan/ljspeech/LJ_V1/generator_v1
HIFIGAN_CFG=~/pretrained_models/hifigan/ljspeech/LJ_V1/config.json
MASTER_ADDR=`hostname -s`
FILELIST_STEM=wav_text_filelist

./sbatch.sh python train.py \
  --dataset-path $DATA_ROOT \
  --output $CHECKPOINT_DIR/$EXP_NAME \
  --training-files $DATA_ROOT/train_meta_half.txt \
  --validation-files $DATA_ROOT/val_meta_half.txt \
  --pitch-mean-std-file $DATA_ROOT/pitches_stats__${FILELIST_STEM}.json \
  --input-type char \
  --symbol-set english_pad_lowercase_nopunc \
  --text-cleaners lowercase_no_punc \
  --epochs 1000 \
  --epochs-per-checkpoint 10 \
  --batch-size 16 \
  --use-mas \
  --cuda \
  --hifigan $HIFIGAN_CHKPT \
  --hifigan-config $HIFIGAN_CFG \
  --use-sepconv \
  --master-addr $MASTER_ADDR \
  --checkpoint-path /home/s1785140/respeller/fastpitch/exps/halved_ljspeech_data_nospaces_noeos_pad_lowercase_nopunc/FastPitch_checkpoint_290.pt
```

In [None]:
# change values of some args to match the config of the pretrained model 
args.local_rank = 0
args.use_mas = True
args.use_sepconv = True
args.cuda = torch.cuda.is_available()
args.input_type = 'char'
args.symbol_set = 'english_pad_lowercase_nopunc'
args.n_speakers = 1
args.fastpitch_chkpt = "/home/s1785140/respeller/fastpitch/exps/halved_ljspeech_data_nospaces_noeos_pad_lowercase_nopunc/FastPitch_checkpoint_1000.pt"

In [None]:
def load_checkpoint(args, model, filepath):
    if args.local_rank == 0:
        print(f'Loading model and optimizer state from {filepath}')
    checkpoint = torch.load(filepath, map_location='cpu')
    sd = {k.replace('module.', ''): v
          for k, v in checkpoint['state_dict'].items()}
    getattr(model, 'module', model).load_state_dict(sd)
    return model

def load_pretrained_fastpitch(args):
    # load chkpt
    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = fastpitch_model.get_model_config('FastPitch', args)
    fastpitch = fastpitch_model.get_model('FastPitch', model_config, device, forward_is_infer=True)
    load_checkpoint(args, fastpitch, args.fastpitch_chkpt)
    # get information about grapheme embedding table
    n_symbols = fastpitch.encoder.word_emb.weight.size(0)
    embedding_dim = fastpitch.encoder.word_emb.weight.size(1)
    return fastpitch, model_config, n_symbols, embedding_dim

# from fastpitch.fastpitch.transformer import FFTransformer
fastpitch, model_config, n_symbols, embedding_dim = load_pretrained_fastpitch(args)
print("Finished loading TTS model!")

## TTSMetric

In [None]:
class TTSMetric:
    def __init__(
        self,
        model_path,
    ):
        self.tts_model = load_fastpitch(model_path)
        self.softdtw_loss = criterion = SoftDTW(
            use_cuda=torch.cuda.is_available(), 
            gamma=hparams["softdtw_temp"], 
            bandwidth=hparams["softdtw_bandwidth"],
            dist_func=hparams["dist_func"],
        )
    
    def __call__(
        self,
        predicted_texts: List[str], # [bsz]
        reference_mels: torch.Tensor, # [bsz, seqlen, dim]
    ) -> float:
        """return softdtw loss between two batches of mel-spectrograms
        averaged across batch dimension"""
        predicted_mels = self.tts_model(predicted_texts)
        return self.softdtw_loss(predicted_mels, reference_mels).mean()

In [None]:
ttsmetric = TTSMetric(hparams["tts_model_path"])

## TTSRewardFunction

In [None]:
class TTSRewardFunction:
    """TTS reward function"""
    def __init__(
        self, 
        model_path: str,
        shaping_fn: str = None,
    ):
        super().__init__()
        self._metric = TTSMetric(model_path)
        
    def __call__(
        self,
        current_observation: Observation,
        action: int,
        next_observation: Observation,
        done: bool,
        meta_info: Dict[str, Any] = None,
    ):
        

# ENVIRONMENT

## create custom env

In [None]:
import gym
from gym import spaces

class ASREnv(gym.Env):
    """Custom Environment that follows gym interface."""
    
    # below taken from Gym code https://github.com/openai/gym/blob/master/gym/core.py
    r"""The main OpenAI Gym class.
    It encapsulates an environment with arbitrary behind-the-scenes dynamics.
    An environment can be partially or fully observed.
    The main API methods that users of this class need to know are:
    - :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
      if the environment terminated and observation information.
    - :meth:`reset` - Resets the environment to an initial state, returning the initial observation and observation information.
    - :meth:`render` - Renders the environment observation with modes depending on the output
    - :meth:`close` - Closes the environment, important for rendering where pygame is imported
    And set the following attributes:
    - :attr:`action_space` - The Space object corresponding to valid actions
    - :attr:`observation_space` - The Space object corresponding to valid observations
    - :attr:`reward_range` - A tuple corresponding to the minimum and maximum possible rewards
    - :attr:`spec` - An environment spec that contains the information used to initialise the environment from `gym.make`
    - :attr:`metadata` - The metadata of the environment, i.e. render modes
    - :attr:`np_random` - The random number generator for the environment
    Note: a default reward range set to :math:`(-\infty,+\infty)` already exists. Set it if you want a narrower range.
    """

    metadata = {"render.modes": ["human"]}

    def __init__(
        self, 
        tokenizer, 
        reward_function,
        samples,
        
    ):
        """Generic RL environment to generate ASR hypotheses from input audio"""
        super().__init__()
        
        self._vocab_size = tokenizer.vocab_size
        self.reward_function = reward_function
        for sample, weight in samples:
            self.sampler_for_replaying.add(sample, weight)
        
        # Define action and observation space
        # They must be gym.spaces objects
        self.action_space = spaces.Discrete(n=self._vocab_size)
        self.observation_space = DictSpace(
            {
                # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited
                # while creating rollout buffers, observations are concatenated for each key
                "prompt_or_input_encoded_pt": spaces.Box(
                    low=0, high=self._vocab_size, shape=(self._max_text_length,)
                ),
                "prompt_or_input_attention_mask_pt": spaces.Box(
                    low=0, high=1, shape=(self._max_text_length,)
                ),
                "context_encoded_pt": spaces.Box(
                    low=0, high=self._vocab_size, shape=(self.max_steps,)
                ),
                "context_attention_mask_pt": spaces.Box(
                    low=0, high=1, shape=(self.max_steps,)
                ),
                "input_encoded_pt": spaces.Box(
                    low=0,
                    high=self._vocab_size,
                    shape=(self._max_text_length + self.max_steps,),
                ),
                "input_attention_mask_pt": spaces.Box(
                    low=0, high=1, shape=(self._max_text_length + self.max_steps,)
                ),
            }
        )

    def step(self, action):
        self.__time_step += 1

        # previous obs
        previous_obs = self.__current_obs

        # just update the context tensor and gets the new observation
        self.__current_obs = self.__current_obs.update(action, self.tokenizer)

        # decide if the episode is finished or not
        done = (action == self.tokenizer.eos_token_id and self._terminate_on_eos) or (
            self.__time_step == self.max_steps
        )

        # compute reward
        if not isinstance(self.reward_function, BatchedRewardFunction):
            reward = (
                None
                if self.reward_function is None
                else self.reward_function(
                    previous_obs,
                    action,
                    self.__current_obs,
                    done,
                    self.__current_obs.meta_info,
                )
            )
        else:
            reward = -inf  # will be overridden later

        # populate additional info
        info = {
            "output": self.__current_obs.context_text,
            "action_history": self.__current_obs.action_history,
            "reference_text": self.__current_obs.target_or_reference_texts,
            "prompt_text": self.__current_obs.prompt_or_input_text,
            "prev_output": previous_obs.context_text,
            "meta_info": previous_obs.meta_info,
        }

        dict_observation = self.__current_obs.to_dict()
        return dict_observation, reward, done, info

    def reset(self):
        """
        Resets the environment and starts a new episode
        """
        # gets a new sample if not provided
        if sample is None:
            sample = self.sampler_for_replaying.sample(size=1)[0]
        self.__current_sample = sample

        # init the observation
        self.__current_obs = Observation.init_from_sample(
            sample,
            self.tokenizer,
            self._max_text_length,
            self.max_steps,
            self._prompt_truncation_side,
            self._context_start_token,
            sample.meta_data,
        )

        # start the time step counter
        self.__time_step = 0

        dict_observation = self.__current_obs.to_dict()
        return dict_observation

    def render(self):
        pass

    def close(self):
        pass

## check that env follows Gym interface

In [None]:
from stable_baselines3.common.env_checker import check_env

env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)

# POLICY/ALGORITHM

In [None]:
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from gym import spaces
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


class PPO(OnPolicyAlgorithm):
    """
    Created with reference to Seq2SeqLMActorCriticPolicy
    
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):

        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )
        # Disable orthogonal initialization
        self.ortho_init = False

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)


model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
model.learn(5000)

# collect rollouts

# create rollout buffer