In [1]:
import torch
import torch.nn as nn
import numpy as np
import pickle
from fastai.text.all import *

In [2]:
if torch.cuda.is_available():
  device = "cuda:0"
else:
  device = "cpu"
device

'cpu'

In [3]:
def scaled_dot_product_attention(query, key, value, mask=None):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / np.sqrt(dim_k)
    if mask is not None:
        if scores.shape[1] == mask.shape[1]:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        else:
            mask = torch.tril(torch.ones(scores.shape[1], scores.shape[1])).unsqueeze(0).to(device)
            scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights.bmm(value)

class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim, vocab_size):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)
        self.mask = torch.tril(torch.ones(vocab_size, vocab_size)).unsqueeze(0).to(device)

    def forward(self, hidden_state):
        attn_outputs = scaled_dot_product_attention(
            self.q(hidden_state), self.k(hidden_state), self.v(hidden_state), self.mask)
        return attn_outputs

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim, 72) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        x = self.output_linear(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
    def forward(self, x):
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x):
        # Apply layer normalization and then copy input into query, key, value
        hidden_state = self.layer_norm_1(x)
        # Apply attention with a skip connection
        x = x + self.attention(hidden_state)
        # Apply feed-forward layer with a skip connection
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x

class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size, 
                                             config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout()

    def forward(self, input_ids):
        # Create position IDs for input sequence
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).to(device)
        # Create token and position embeddings
        token_embeddings = self.token_embeddings(input_ids)
        
        position_embeddings = self.position_embeddings(position_ids)
        # Combine token and position embeddings
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        self.layers = nn.ModuleList([TransformerEncoderLayer(config) 
                                     for _ in range(config.num_hidden_layers)])

    def forward(self, x):
        x = self.embeddings(x)
        for layer in self.layers:
            x = layer(x)
        return x

class ShellTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        
    def forward(self, x):
        #print(x)
        x = self.encoder(x)#[:, 0, :] # select hidden state of [CLS] token
        #print(x)
        x = self.dropout(x)
        x = self.classifier(x)
        #print(x)
        return x

In [4]:
with open('/home/chris/University/gnn_project/dataset', 'rb') as fp:
    _ = pickle.load(fp)

In [318]:
_[0:10]

['nmap\n',
 'nmap -v 10.1.26.4\n',
 'nmap -v 10.1.26.9\n',
 'ssh --help\n',
 'ssh 10.1.26.9\n',
 'ssh 10.1.26.9 admin/123456\n',
 'ssh --help\n',
 'ssh 10.1.26.9\n',
 'ssh -l admin 10.1.26.9\n',
 'ssh admin@admin 10.1.26.9\n']

In [67]:
import os

# Given array of text elements
texts = _

# Create a folder to store the text files
folder_path = '/home/chris/University/gnn_project/data/'
os.makedirs(folder_path, exist_ok=True)

# Write each non-empty text element to a separate file
for i, text in enumerate(texts):
    # Remove trailing newline character
    text = text.rstrip('\n')
    
    # Check if text is not empty after stripping newline
    if text.strip():
        file_path = os.path.join(folder_path, f'text_{i}.txt')
        with open(file_path, 'w') as file:
            file.write(text)


In [5]:
txts = L(_)

In [7]:
len(txts)

203101

# Tokenizer 

### IMDB Tokenizer

In [157]:
class MyTokenizer(Transform):
    def setups(self, items):
        path = untar_data(URLs.IMDB)
        self.tok =  Tokenizer.from_folder(path)
        self.tok.setup(items)
        
    def encodes(self, txts):
        with open(txts, 'r') as file:
            content = file.read()
        return self.tok(content)
    
    def decodes(self, encoded):
        return self.tok.decode(encoded)
            
class MyNumerizer(Transform):
    def setups(self, items):
        self.num = Numericalize()
        self.num.setup(items)
        self.vocab = self.num.vocab

    def encodes(self, toks):
        return self.num(toks)
    
    def decodes(self, encoded):
        return self.num.decode(encoded)
    
limit = 100
path_test = '/home/chris/University/gnn_project/'
tfms = [[MyTokenizer(),MyNumerizer()]]
files = get_text_files(path_test, folders = ['data'])
dsets = Datasets(files[:limit], tfms)
dls = dsets.dataloaders(dl_type=LMDataLoader, bs=4)

dls.show_batch(max_n=10)

Unnamed: 0,text,text_
0,xxbos xxunk xxunk - xxunk xxbos vim xxunk / xxunk / xxunk xxunk xxbos xxunk xxunk xxbos clear xxbos git xxunk . xxbos ls xxbos xxunk xxbos git xxunk xxunk xxunk xxunk / xxunk / xxunk / xxunk xxbos sudo xxunk xxunk xxbos bash : xxunk xxunk : command not found xxbos xxunk xxunk xxunk xxbos xxunk vim xxbos git xxunk xxbos vim xxunk xxbos xxunk xxunk -p xxunk -p xxunk xxunk,xxunk xxunk - xxunk xxbos vim xxunk / xxunk / xxunk xxunk xxbos xxunk xxunk xxbos clear xxbos git xxunk . xxbos ls xxbos xxunk xxbos git xxunk xxunk xxunk xxunk / xxunk / xxunk / xxunk xxbos sudo xxunk xxunk xxbos bash : xxunk xxunk : command not found xxbos xxunk xxunk xxunk xxbos xxunk vim xxbos git xxunk xxbos vim xxunk xxbos xxunk xxunk -p xxunk -p xxunk xxunk xxunk
1,xxunk xxunk xxunk xxrep xxunk xxunk xxbos xxunk xxbos vim xxunk xxbos xxunk xxbos vim xxunk xxbos xxunk xxunk xxmaj xxunk xxbos ls xxunk xxbos sudo xxunk - xxunk xxunk xxunk - xxunk xxbos git xxunk xxbos ls xxbos # xxunk xxbos xxunk xxunk xxunk xxunk xxunk -p xxbos ls xxunk xxbos xxunk xxunk : / xxunk / xxunk / xxunk - xxunk xxbos ls xxbos xxunk xxunk xxunk xxbos xxunk xxunk,xxunk xxunk xxrep xxunk xxunk xxbos xxunk xxbos vim xxunk xxbos xxunk xxbos vim xxunk xxbos xxunk xxunk xxmaj xxunk xxbos ls xxunk xxbos sudo xxunk - xxunk xxunk xxunk - xxunk xxbos git xxunk xxbos ls xxbos # xxunk xxbos xxunk xxunk xxunk xxunk xxunk -p xxbos ls xxunk xxbos xxunk xxunk : / xxunk / xxunk / xxunk - xxunk xxbos ls xxbos xxunk xxunk xxunk xxbos xxunk xxunk /
2,"xxbos clear xxbos vim xxunk / xxunk xxbos # xxunk xxbos # xxunk xxrep xxunk xxunk xxunk xxbos xxunk xxup xxunk xxunk xxbos ls xxbos xxunk -c xxunk : / / xxunk xxbos ls xxbos cat xxunk xxbos xxunk xxbos pwd xxbos ls xxbos xxunk . / xxunk "" xxunk xxunk "" xxunk xxunk xxunk xxunk xxunk xxunk xxbos xxunk xxunk / xxunk xxbos xxunk xxbos sudo xxunk xxunk xxbos ls xxunk","clear xxbos vim xxunk / xxunk xxbos # xxunk xxbos # xxunk xxrep xxunk xxunk xxunk xxbos xxunk xxup xxunk xxunk xxbos ls xxbos xxunk -c xxunk : / / xxunk xxbos ls xxbos cat xxunk xxbos xxunk xxbos pwd xxbos ls xxbos xxunk . / xxunk "" xxunk xxunk "" xxunk xxunk xxunk xxunk xxunk xxunk xxbos xxunk xxunk / xxunk xxbos xxunk xxbos sudo xxunk xxunk xxbos ls xxunk xxbos"
3,"xxbos cat xxunk xxunk xxbos cat xxunk / xxunk xxbos sudo xxunk xxbos xxunk xxunk - xxunk - xxunk - xxunk - xxunk - xxunk xxbos pwd xxbos xxunk xxunk xxbos xxunk xxbos # xxunk xxbos git xxunk xxunk "" xxunk xxunk "" xxbos bash : xxunk xxunk : command not found xxbos ls xxbos xxunk xxunk xxunk xxunk xxbos ls xxbos xxunk xxunk xxunk - xxunk xxbos xxunk xxunk xxunk -c","cat xxunk xxunk xxbos cat xxunk / xxunk xxbos sudo xxunk xxbos xxunk xxunk - xxunk - xxunk - xxunk - xxunk - xxunk xxbos pwd xxbos xxunk xxunk xxbos xxunk xxbos # xxunk xxbos git xxunk xxunk "" xxunk xxunk "" xxbos bash : xxunk xxunk : command not found xxbos ls xxbos xxunk xxunk xxunk xxunk xxbos ls xxbos xxunk xxunk xxunk - xxunk xxbos xxunk xxunk xxunk -c :"


### SubwordTokenizer

In [172]:
class MyTokenizer(Transform):
    def setups(self, items):
        self.tok = SubwordTokenizer(vocab_sz=200)
        self.tok.setup(items)
        
    def encodes(self, txts):
        with open(txts, 'r') as file:
            content = file.read()
        flattened_list = [item for sublist in list(self.tok(content)) for item in sublist]
        return flattened_list
    
    def decodes(self, encoded):
        decoded_values = TitledStr(''.join(encoded))
        return  decoded_values
            
class MyNumerizer(Transform):
    def setups(self, items):
        self.num = Numericalize()
        self.num.setup(items)
        self.vocab = self.num.vocab
        

    def encodes(self, toks):
        return self.num(toks)
    
    def decodes(self, encoded):
        return self.num.decode(encoded)  

limit = 100
path_test = '/home/chris/University/gnn_project/'
tfms = [[MyTokenizer(),MyNumerizer()]]
files = get_text_files(path_test, folders = ['data'])
dsets = Datasets(files[:limit], tfms)
dls = dsets.dataloaders(dl_type=LMDataLoader, bs=64)

dls.show_batch(max_n=10)

Unnamed: 0,text,text_
0,▁l▁s▁l▁s▁-▁l▁h▁g▁i▁t▁l▁o▁g▁l▁s▁#▁1▁5▁1▁7▁1▁1▁5▁5▁8▁8▁,l▁s▁l▁s▁-▁l▁h▁g▁i▁t▁l▁o▁g▁l▁s▁#▁1▁5▁1▁7▁1▁1▁5▁5▁8▁8▁f
1,"f▁i▁n▁d▁.▁/▁-▁n▁a▁m▁e▁""▁xxunk▁.▁f▁q▁.▁g▁z▁""▁|▁z▁c▁a▁t▁|▁g","▁i▁n▁d▁.▁/▁-▁n▁a▁m▁e▁""▁xxunk▁.▁f▁q▁.▁g▁z▁""▁|▁z▁c▁a▁t▁|▁g▁"
2,▁r▁e▁p▁xxunk▁H▁xxunk▁I▁xxunk▁S▁I▁xxunk▁p▁h▁p▁g▁e▁t▁N▁o▁t▁i▁c▁e▁.▁p▁h▁,r▁e▁p▁xxunk▁H▁xxunk▁I▁xxunk▁S▁I▁xxunk▁p▁h▁p▁g▁e▁t▁N▁o▁t▁i▁c▁e▁.▁p▁h▁p
3,p▁1▁xxunk▁o▁u▁t▁p▁u▁t▁0▁.▁t▁x▁t▁xxunk▁s▁b▁a▁s▁h▁:▁xxunk▁g▁o▁a▁l▁a,▁1▁xxunk▁o▁u▁t▁p▁u▁t▁0▁.▁t▁x▁t▁xxunk▁s▁b▁a▁s▁h▁:▁xxunk▁g▁o▁a▁l▁a▁
4,▁d▁o▁r▁@▁g▁a▁t▁a▁n▁d▁a▁:▁c▁o▁m▁m▁a▁n▁d▁n▁o▁t▁f▁o▁u▁n▁,d▁o▁r▁@▁g▁a▁t▁a▁n▁d▁a▁:▁c▁o▁m▁m▁a▁n▁d▁n▁o▁t▁f▁o▁u▁n▁d
5,d▁l▁s▁v▁i▁m▁t▁o▁t▁o▁.▁s▁h▁.▁.▁/▁c▁o▁n▁f▁i▁g▁u▁r▁e▁l▁s,▁l▁s▁v▁i▁m▁t▁o▁t▁o▁.▁s▁h▁.▁.▁/▁c▁o▁n▁f▁i▁g▁u▁r▁e▁l▁s▁
6,▁l▁s▁v▁i▁m▁V▁a▁l▁i▁d▁a▁t▁e▁P▁a▁s▁s▁.▁s▁h▁s▁u▁d▁o▁a▁p▁,l▁s▁v▁i▁m▁V▁a▁l▁i▁d▁a▁t▁e▁P▁a▁s▁s▁.▁s▁h▁s▁u▁d▁o▁a▁p▁t
7,t▁-▁g▁e▁t▁i▁n▁s▁t▁a▁l▁l▁p▁y▁t▁h▁o▁n▁-▁p▁i▁p▁p▁w▁d▁c▁a,▁-▁g▁e▁t▁i▁n▁s▁t▁a▁l▁l▁p▁y▁t▁h▁o▁n▁-▁p▁i▁p▁p▁w▁d▁c▁a▁
8,▁t▁xxunk▁s▁h▁m▁.▁c▁i▁f▁c▁o▁n▁f▁i▁g▁#▁1▁3▁5▁7▁3▁3▁1▁6▁5▁6▁,t▁xxunk▁s▁h▁m▁.▁c▁i▁f▁c▁o▁n▁f▁i▁g▁#▁1▁3▁5▁7▁3▁3▁1▁6▁5▁6▁d
9,d▁o▁c▁k▁e▁r▁r▁u▁n▁-▁p▁2▁1▁8▁1▁:▁2▁1▁8▁1▁-▁p▁9▁0▁9▁2▁:,▁o▁c▁k▁e▁r▁r▁u▁n▁-▁p▁2▁1▁8▁1▁:▁2▁1▁8▁1▁-▁p▁9▁0▁9▁2▁:▁


### BaseTokenizer

In [159]:
class MyTokenizer(Transform):
    def setups(self, items):
        self.tok = BaseTokenizer()
        #self.tok.setup(items)
        
    def encodes(self, txts):
        with open(txts, 'r') as file:
            content = file.read()
        flattened_list = [item for sublist in list(self.tok(content)) for item in sublist]
        return flattened_list
    
    def decodes(self, encoded):
        decoded_values = TitledStr(''.join(encoded))
        return  decoded_values
            
class MyNumerizer(Transform):
    def setups(self, items):
        self.num = Numericalize()
        self.num.setup(items)
        self.vocab = self.num.vocab
        

    def encodes(self, toks):
        return self.num(toks)
    
    def decodes(self, encoded):
        return self.num.decode(encoded)  

limit = 100
path_test = '/home/chris/University/gnn_project/'
tfms = [[MyTokenizer(),MyNumerizer()]]
files = get_text_files(path_test, folders = ['data'])
dsets = Datasets(files[:limit], tfms)
dls = dsets.dataloaders(dl_type=LMDataLoader, bs=64)

dls.show_batch(max_n=10)

Unnamed: 0,text,text_
0,"find./-name""xxunk.fq.gz""|zc","ind./-name""xxunk.fq.gz""|zca"
1,at|grepxxunkHxxunkIxxunkSIxxunksetRHOST,t|grepxxunkHxxunkIxxunkSIxxunksetRHOSTS
2,S172.18.1.5netstat-an|grep,172.18.1.5netstat-an|grep
3,tcpsudoifdowneth0sudonet,tcpsudoifdowneth0sudonets
4,stat-planttrab1impython3man,tat-planttrab1impython3mana
5,age.pyredis-clusterloginlsifc,ge.pyredis-clusterloginlsifco
6,onfiglsls-lls#1516822211git,nfiglsls-lls#1516822211git
7,add.slsudodnfupdate--,add.slsudodnfupdate--r
8,refresh-ycattasks/packages.,efresh-ycattasks/packages.y
9,ymlphpgetNotice.php1xxunkou,mlphpgetNotice.php1xxunkout


### SpacyTokenizer

In [160]:
class MyTokenizer(Transform):
    def setups(self, items):
        self.tok = SpacyTokenizer()
        #self.tok.setup(items)
        
    def encodes(self, txts):
        with open(txts, 'r') as file:
            content = file.read()
        flattened_list = [item for sublist in list(self.tok(content)) for item in sublist]
        return flattened_list
    
    def decodes(self, encoded):
        decoded_values = TitledStr(''.join(encoded))
        return  decoded_values
            
class MyNumerizer(Transform):
    def setups(self, items):
        self.num = Numericalize()
        self.num.setup(items)
        self.vocab = self.num.vocab
        

    def encodes(self, toks):
        return self.num(toks)
    
    def decodes(self, encoded):
        return self.num.decode(encoded)  
    
limit = 100
path_test = '/home/chris/University/gnn_project/'
tfms = [[MyTokenizer(),MyNumerizer()]]
files = get_text_files(path_test, folders = ['data'])
dsets = Datasets(files[:limit], tfms)
dls = dsets.dataloaders(dl_type=LMDataLoader, bs=64)

dls.show_batch(max_n=10)

Unnamed: 0,text,text_
0,ssh heat-admin@172.16.0.26bas,sh heat-admin@172.16.0.26bash
1,h: xxunkgoalador@gatanda: command,: xxunkgoalador@gatanda: command
2,not foundslgit pull#13577894,not foundslgit pull#135778949
3,"94find ./ -name ""xxunk.fq.gz"" |zc","4find ./ -name ""xxunk.fq.gz"" |zca"
4,at | grep xxunkHxxunkIxxunkSIxxunkseq primeir,t | grep xxunkHxxunkIxxunkSIxxunkseq primeiro
5,omysql -h alas -u ml12087 -p#,mysql -h alas -u ml12087 -p#1
6,1473234314pwdlsNov 23xxunk 2010 -,473234314pwdlsNov 23xxunk 2010 -
7,4 posts - xxunk2 authorsopenstac,4 posts - xxunk2 authorsopenstack
8,k baremetal node listset rpor,baremetal node listset rport
9,t 10000sudo rebootphp getNoti,10000sudo rebootphp getNotic


### WordTokenize

In [168]:
class MyTokenizer(Transform):
    def setups(self, items):
        self.tok = WordTokenizer()
        #self.tok.setup(items)
        
    def encodes(self, txts):
        with open(txts, 'r') as file:
            content = file.read()
        flattened_list = [item for sublist in list(self.tok(content)) for item in sublist]
        return flattened_list
    
    def decodes(self, encoded):
        decoded_values = TitledStr(''.join(encoded))
        return  decoded_values
            
class MyNumerizer(Transform):
    def setups(self, items):
        self.num = Numericalize()
        self.num.setup(items)
        self.vocab = self.num.vocab
        

    def encodes(self, toks):
        return self.num(toks)
    
    def decodes(self, encoded):
        return self.num.decode(encoded)  

limit = 1000
path_test = '/home/chris/University/gnn_project/'
tfms = [[MyTokenizer(),MyNumerizer()]]
files = get_text_files(path_test, folders = ['data'])
dsets = Datasets(files[:limit], tfms)
dls = dsets.dataloaders(dl_type=LMDataLoader, bs=64)

dls.show_batch(max_n=10)


Unnamed: 0,text,text_
0,vi index.php vi test2.pygem install bundlerllfcrackzip -hecho $gopathsud,i index.php vi test2.pygem install bundlerllfcrackzip -hecho $gopathsudo
1,hmod 777 group/#1567522013git checkout -- app/databases/1/1.FDBllwalgit,mod 777 group/#1567522013git checkout -- app/databases/1/1.FDBllwalgit p
2,"config --global user.name ""XiaChu""python3 test_Publish.py circle1_offlo","config --global user.name ""XiaChu""python3 test_Publish.py circle1_offlof"
3,p.out Trinity.timing nano userinfo.shgit statuslsmdk d service rm %%_gr,.out Trinity.timing nano userinfo.shgit statuslsmdk d service rm %%_gra
4,undchmod +x what-time-script.sh wiload.shpip install pyserialclear-d: -f,ndchmod +x what-time-script.sh wiload.shpip install pyserialclear-d: -f2
5,blocks/ certificates/ chainstate/ *.lo wallet.dat docker run aaabfec43f,blocks/ certificates/ chainstate/ *.lo wallet.dat docker run aaabfec43f5
6,-U /var/cache/pacman/pkgname-olderpkgver.pkg.tar.gzfind /isitools//CentO,U /var/cache/pacman/pkgname-olderpkgver.pkg.tar.gzfind /isitools//CentOS
7,ash: [goalador@gatanda: command not foundsu airsudo apt-get install pyth,sh: [goalador@gatanda: command not foundsu airsudo apt-get install pytho
8,not foundvi FirebirdSQLCreator.php#1397676725git statusnum=5git checkout,ot foundvi FirebirdSQLCreator.php#1397676725git statusnum=5git checkout
9,7988940 1166784 88% /bootlspwd> > > [goalador@gatanda dp-s]$ echo $P,7988940 1166784 88% /bootlspwd> > > [goalador@gatanda dp-s]$ echo $PR


In [162]:
from transformers import AutoConfig
config = AutoConfig.from_pretrained('bert-base-uncased')

config.vocab_size = len(dls.vocab)
config.num_labels = len(dls.vocab)
config.hidden_size = 132
transformer = ShellTransformer(config)

In [163]:
model = transformer

model.to(device)

dls.to(device)

learn = Learner(
    dls, 
    model, 
    loss_func=CrossEntropyLossFlat(), 
    metrics=[accuracy]
)

learn.fit_one_cycle(1, 1e-3)

#learn.export('mymodel.pkl')

epoch,train_loss,valid_loss,accuracy,time
0,4.836963,,,00:01


  warn("Your generator is empty.")


In [173]:
class MyTokenizer(Transform):
    def setups(self, items):
        self.tok = SubwordTokenizer(vocab_sz=200)
        self.tok.setup(items)
        #self.tok = WordTokenizer()
        
    def encodes(self, paths):
        for path in paths:
            with open(path, 'r') as file:
                content = file.read()
        flattened_list = [item for sublist in list(self.tok(content)) for item in sublist]
        return flattened_list
    
    def decodes(self, encoded):
        decoded_values = TitledStr(''.join(encoded))
        return  decoded_values
            
class MyNumerizer(Transform):
    def setups(self, items):
        self.num = Numericalize()
        self.num.setup(items)
        self.vocab = self.num.vocab
        

    def encodes(self, toks):
        return self.num(toks)
    
    def decodes(self, encoded):
        return self.num.decode(encoded)  
    
limit = 100
path_test = '/home/chris/University/gnn_project/'
files = get_text_files(path_test, folders = ['data'])
tok = MyTokenizer()
tok.setup(files[:limit])
paths = [str(file) for file in files[:]]
toks = tok(paths)

num = MyNumerizer()
num.setup(toks)
#nums = [num(tok) for tok in toks]


In [174]:
mytokenizer = Pipeline([tok,num])

# Define a function for text generation
def generate_text(model, starting_text, max_length=2):
    token_ids = mytokenizer(starting_text).to(device)
    input_ids = torch.tensor(token_ids).unsqueeze(0).to(device)  
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            logits = outputs[:, -1, :]  
            next_token_id = torch.argmax(logits, dim=-1)
            token_ids = torch.cat((token_ids, next_token_id),dim=0)
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0).to(device)], dim=-1)
            #if mytokenizer.decode(next_token_id) == 'xxboss':
            #    break
            print(token_ids)
    return mytokenizer.decode(token_ids)

# Generate text
path = "text_generation.txt"
generated_text = generate_text(learn.model, [path])
with open(path, 'r') as file:
     content = file.read()
print(content+generated_text)


  input_ids = torch.tensor(token_ids).unsqueeze(0).to(device)


TensorText([ 9,  0,  9,  0,  9,  0,  9,  0,  9,  0,  9,  0,  9,  0,  9,  0,  9,
             0, 12])
TensorText([ 9,  0,  9,  0,  9,  0,  9,  0,  9,  0,  9,  0,  9,  0,  9,  0,  9,
             0, 12, 73])


IndexError: list index out of range

In [175]:
mytokenizer

Pipeline: MyTokenizer -> MyNumerizer

In [176]:
mytokenizer = Pipeline([tok,num])

# Define a function for text generation
def generate_text(model, starting_text, max_length=10):
    token_ids = mytokenizer(starting_text).to(device)
    input_ids = torch.tensor(token_ids).unsqueeze(0).to(device)  
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            logits = outputs[:, -1, :]  
            next_token_id = torch.argmax(logits, dim=-1)
            token_ids = torch.cat((token_ids, next_token_id),dim=0)
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0).to(device)], dim=-1)
            if mytokenizer.decode(next_token_id) == 'xxboss':
                break
    return mytokenizer.decode(token_ids)

# Generate text
generated_text = generate_text(learn.model, "shell -")
print(generated_text)


FileNotFoundError: [Errno 2] No such file or directory: 's'