<a href="https://colab.research.google.com/github/msh2481/CodeStyler/blob/main/Recurrent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [217]:
!rm -rf ./*
!git clone https://github.com/msh2481/CodeStyler.git && mv CodeStyler/* . && rm -rf CodeStyler
!ls

Cloning into 'CodeStyler'...
remote: Enumerating objects: 7930, done.[K
remote: Counting objects: 100% (7930/7930), done.[K
remote: Compressing objects: 100% (6530/6530), done.[K
remote: Total 7930 (delta 1402), reused 7917 (delta 1399), pack-reused 0[K
Receiving objects: 100% (7930/7930), 9.11 MiB | 15.07 MiB/s, done.
Resolving deltas: 100% (1402/1402), done.
Baseline.ipynb	Feedforward.ipynb  filenames.txt  files  README.md


In [218]:
from random import shuffle, choices, choice
from collections import deque, defaultdict, Counter
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

In [247]:
CHUNK_SIZE = 128
TRAIN_SIZE = 1000
TEST_SIZE = 100
MIN_OCCURENCES = 10
BATCH_SIZE = 64
MEMORY = 128
DEPTH = 3

In [248]:
def fmt(number):
    return '{:.5f}'.format(number)

In [249]:
rawTexts = []
alphabet = Counter()
for filename in open('filenames.txt'):
    if len(rawTexts) > TRAIN_SIZE + TEST_SIZE:
        break
    text = open(filename.strip()).read()
    if 'debug' in text or 'DEBUG' in text:
        continue
    alphabet.update(text)
    for pos in range(0, len(text) - CHUNK_SIZE + 1):
        rawTexts.append(text[pos : pos + CHUNK_SIZE])
alphabetCount = Counter()
alphabetCount['█'] = 0
for x, y in alphabet.items():
    if y >= MIN_OCCURENCES:
        alphabetCount[x] += y
    else:
        alphabetCount['█'] += y
alphabet = [x for x, y in alphabetCount.items()]
ALPHABET_SIZE = len(alphabet)
print(f'alphabet of length {len(alphabet)}: {alphabetCount}')

shuffle(rawTexts)
print(f'{len(rawTexts)} texts in total')
print(rawTexts[:3])

alphabet of length 47: Counter({' ': 373, '█': 108, '\n': 96, 'n': 87, 'A': 83, 'e': 68, 'y': 62, 'c': 52, '(': 52, ')': 52, 'E': 50, 'u': 49, 'S': 45, ':': 42, 't': 42, 'a': 36, 'f': 32, 'r': 32, 'l': 30, 'T': 29, 'h': 29, '/': 28, '"': 27, 's': 26, '{': 24, '}': 24, 'b': 22, 'U': 20, '_': 20, '=': 20, 'i': 20, '0': 20, 'N': 19, 'p': 17, 'R': 16, '>': 16, 'C': 15, '1': 15, 'B': 14, 'k': 14, '<': 14, 'W': 13, 'g': 13, 'M': 12, 'v': 12, 'o': 12, 'z': 11})
1786 texts in total
['TCASE NUMBER: 9\nfun case_9() {\n    checkSubtype<Any>(object {}::class)\n    val z: Any = {}::class\n    funWithAnyArg(0E0::class)\n', 'SED_VARIABLE -UNUSED_VALUE -UNUSED_PARAMETER\n// !LANGUAGE: +NewInference\n// SKIP_TXT\n\n// TESTCASE NUMBER: 1\nfun case_1() {\n    c', 'IAGNOSTICS: -UNUSED_VARIABLE -ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE -UNUSED_VALUE -UNUSED_PARAMETER\n// !LANGUAGE: +NewInference\n/']


In [250]:
charToIndexMap = { c : i for i, c in enumerate(alphabet) }
def charToIndex(c):
    return torch.as_tensor(charToIndexMap.get(c, ALPHABET_SIZE - 1), dtype=torch.long)

def stringToTensor(cur):
    x = torch.zeros(size=(len(cur), ALPHABET_SIZE))
    for j in range(len(cur)):
        x[j][charToIndex(cur[j])] = 1
    return x

class StringDataset(Dataset):
    def __init__(self, strings):
        super(StringDataset, self).__init__()
        self.strings = strings
    def __len__(self):
        return len(self.strings)
    def __getitem__(self, i):
        return stringToTensor(self.strings[i])

trainSet = DataLoader(StringDataset(rawTexts[: TRAIN_SIZE]), batch_size=BATCH_SIZE, shuffle=True)
testSet = DataLoader(StringDataset(rawTexts[TRAIN_SIZE : TRAIN_SIZE + TEST_SIZE]), batch_size=BATCH_SIZE, shuffle=False)

In [None]:
print(len(trainSet), len(testSet))
print('---')
print(next(iter(trainSet)))
print('---')

In [285]:
class Predictor(nn.Module):
    def __init__(self):
        super(Predictor, self).__init__()
        self.linear1 = nn.Linear(ALPHABET_SIZE + MEMORY, ALPHABET_SIZE + MEMORY, dtype=torch.double)
        self.linear2 = nn.Linear(ALPHABET_SIZE + MEMORY, ALPHABET_SIZE + MEMORY, dtype=torch.double)
        self.batchNorm = nn.BatchNorm1d(ALPHABET_SIZE + MEMORY, dtype=torch.double)
    
    def forward(self, state, answer):
        assert state.shape[1:] == (MEMORY, )
        if answer.shape[1:] != (ALPHABET_SIZE, ):
            print(state.shape, answer.shape)
        assert answer.shape[1:] == (ALPHABET_SIZE, )
        inputTensor = torch.cat((state, answer), dim=1)
        deltaTensor = self.batchNorm(self.linear1(inputTensor))
        updateTensor = F.softmax(self.linear2(inputTensor), dim=-1)
        resultTensor = inputTensor + torch.mul(updateTensor, deltaTensor)
        state, answer = resultTensor[:, : MEMORY], resultTensor[:, MEMORY :]
        return F.tanh(state), F.softmax(answer, dim=-1)

In [286]:
lossFunction = nn.NLLLoss()

def evaluateOnBatch(predictor, batch):
    N = batch.shape[0]
    batch = batch.permute(1, 0, 2)
    assert batch.shape == (CHUNK_SIZE, N, ALPHABET_SIZE)
    state = torch.zeros((N, MEMORY), dtype=torch.double, requires_grad=True)
    answer = torch.zeros((N, ALPHABET_SIZE), dtype=torch.double, requires_grad=True)
    loss = torch.tensor(0, dtype=torch.double)
    accuracy = torch.tensor(0, dtype=torch.double)
    for i in range(CHUNK_SIZE):
        expected = batch[i].argmax(dim=-1)
        assert expected.shape == (N, )
        for it in range(DEPTH):
            state, answer = predictor(state, answer)
            assert state.shape == (N, MEMORY)
            assert answer.shape == (N, ALPHABET_SIZE)
            loss += lossFunction(answer.log(), expected)
        accuracy += (answer.argmax(dim=-1) == expected).double().mean()
        answer = batch[i]
    return accuracy / CHUNK_SIZE, loss / (CHUNK_SIZE * DEPTH)
        
def train(predictor, epochs, startEpoch):
    optimizer = torch.optim.Adam(predictor.parameters())
    for epoch in range(epochs):
        predictor.train()
        trainAccuracy = 0
        trainLogLoss = 0
        trainSize = 0
        for batch in tqdm(trainSet):
            optimizer.zero_grad()
            accuracy, loss = evaluateOnBatch(predictor, batch)
            loss.backward()
            optimizer.step()
            trainAccuracy += accuracy
            trainLogLoss += loss.item()
        trainAccuracy /= len(trainSet) 
        trainLogLoss /= len(trainSet) 

        with torch.no_grad():
            predictor.eval()
            testAccuracy = 0
            testLogLoss = 0
            testSize = 0
            for batch in tqdm(testSet):
                accuracy, logLoss = evaluateOnBatch(predictor, batch)
                testAccuracy += accuracy
                testLogLoss += loss.item()
            testAccuracy /= len(testSet)  
            testLogLoss /= len(testSet) 
            print(f'#{startEpoch + epoch}: {fmt(trainAccuracy)} {fmt(trainLogLoss)} {fmt(testAccuracy)} {fmt(testLogLoss)}')
        

In [287]:
def samplePrediction(predictor, length):
    s = ''
    sFull = ''
    state = torch.rand((1, MEMORY), dtype=torch.double)
    answer = torch.rand((1, ALPHABET_SIZE), dtype=torch.double)
    for i in range(length):
        for it in range(DEPTH):
            state, answer = predictor(state, answer)
            guess = answer[0].argmax(dim=-1).item()
            sFull += alphabet[guess]
        w = list(*answer.detach())
        guess = choices(alphabet, w)[0]
        answer = torch.zeros((1, ALPHABET_SIZE))
        answer[0][charToIndex(guess)] = 1
        sFull += '>' + guess
        s += guess
    print(f'=== {len(s)}, {len(sFull)} ===')
    print(f's:{s}')
    print(f'sFull:{sFull}')

In [None]:
predictor = Predictor()

for i in range(300):
    train(predictor, 1, i)
    # for e in predictor.parameters():
    #     print(e.grad)
    print(i)
    samplePrediction(predictor, 64)

  0%|          | 0/16 [00:00<?, ?it/s]



  0%|          | 0/2 [00:00<?, ?it/s]

#0: 0.05942 3.81320 0.04281 3.80861
0
=== 64, 320 ===
s:h1kvRgzofcbeSvMElcf1feSaN  yguc0g0ik {n█N:spgi}<1
rg/_l{WvBp<z=<
sFull:}UU>hhUU>111U>kkkU>vvvU>RRRU>ggUU>zzz=>ooo=>fff=>ccc=>bbb=>eee=>SSS=>vvv=>MMM=>EEE=>lll=>ccc=>ffff>111f>ffff>eeef>SSSf>aaaf>NNNf>   f>   f>yyyf>gggf>uuuf>cccf>000f>gggf>000f>iiif>kkkf>   f>{{{f>nnnf>███f>NNNf>:::f>sssE>pppE>gggE>iiiE>}}}E><<<E>111E>


E>rrrE>gggE>///E>___E>lllE>{{{E>WWWE>vvvE>BBBE>pppE><<<E>zzzE>===E><


  0%|          | 0/16 [00:00<?, ?it/s]