<a href="https://colab.research.google.com/github/msh2481/CodeStyler/blob/main/Baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [30]:
!rm -rf ./*
!git clone https://github.com/msh2481/CodeStyler.git && mv CodeStyler/* . && rm -rf CodeStyler
!ls

Cloning into 'CodeStyler'...
remote: Enumerating objects: 7921, done.[K
remote: Counting objects: 100% (7921/7921), done.[K
remote: Compressing objects: 100% (6521/6521), done.[K
remote: Total 7921 (delta 1399), reused 7918 (delta 1399), pack-reused 0[K
Receiving objects: 100% (7921/7921), 9.07 MiB | 16.27 MiB/s, done.
Resolving deltas: 100% (1399/1399), done.
filenames.txt  files  README.md


In [413]:
from random import shuffle, choices
from collections import deque, defaultdict, Counter
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm.notebook import tqdm

In [84]:
TARGET_TEXT_SIZE = 2 ** 8
TRAIN_PROPORTION = 0.7
ALPHABET_SIZE = 256

In [207]:
def fmt(number):
    return '{:.5f}'.format(number)

In [421]:
rawTexts = []
for filename in open('filenames.txt'):
    text = open(filename.strip()).read()
    parts = len(text) // TARGET_TEXT_SIZE + 1
    partLen = len(text) // parts
    for pos in range(0, len(text), partLen):
        rawTexts.append(text[pos : pos + partLen])
alphabet = sorted(list(set([ch for text in rawTexts for ch in text])))
alphabet = alphabet + ['?'] * (ALPHABET_SIZE - len(alphabet))
assert len(alphabet) == ALPHABET_SIZE

In [422]:
print(f'alphabet of length {len(alphabet)}: {alphabet}')
shuffle(rawTexts)
print(f'{len(rawTexts)} texts in total')
print(rawTexts[:10])

alphabet of length 256: ['\x00', '\x02', '\t', '\n', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '\xa0', '§', '©', '«', '®', '±', '²', 'µ', '·', '»', '¿', 'Ä', 'Å', 'Ö', 'Ü', 'ß', 'ä', 'å', 'ö', 'ü', 'İ', 'ı', 'ŉ', 'ſ', 'ƻ', 'Ǆ', 'ǅ', 'ǆ', 'Ǉ', 'ǈ', 'ǉ', 'Ǌ', 'ǋ', 'ǌ', 'Ǳ', 'ǲ', 'ǳ', 'ʰ', 'ʼ', 'ͅ', 'Β', 'Ε', 'Θ', 'Ι', 'Κ', 'Μ', 'Π', 'Ρ', 'Σ', 'Φ', 'Ω', 'β', 'ε', 'θ', 'ι', 'κ', 'μ', 'π', 'ρ', 'ς', 'σ', 'φ', 'ω', 'ϐ', 'ϑ', 'ϕ', 'ϖ', 'ϰ', 'ϱ', 'ϴ', 'ϵ', 'К', 'П', 'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у', 'ф'

In [441]:
charToIndexMap = { c : i for i, c in enumerate(alphabet) }
def charToIndex(c):
    return charToIndexMap.get(c, ALPHABET_SIZE - 1)

dataset = rawTexts[:100]

In [442]:
TRAIN_LENGTH = int(len(dataset) * TRAIN_PROPORTION)
trainSet, testSet = dataset[:TRAIN_LENGTH], dataset[TRAIN_LENGTH:]

In [443]:
print(trainSet[0])
print('---')
print(testSet[0])
print('---')
print(len(trainSet), len(testSet))

/*
 * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.fir.analy
---
turn resultingApplicabilities.minOrNull() ?: CandidateApplicability.RESOLVED
        }

    open val resolutionSequence: List<ResolutionPart> get() = resolvedCall.atom.callKind.resolutionSequence

    protected abstract val baseSystem: ConstraintStor
---
70 30


In [444]:
class Predictor(nn.Module):
    def __init__(self, SUFFIX_SIZE):
        super(Predictor, self).__init__()
        self.builded = False
        self.SUFFIX_SIZE = SUFFIX_SIZE
        self.suffixToWeights = nn.ParameterDict()
        self.startNewText()
        self.dryRun()
        self.builded = True
    
    def toKey(self, string):
        return str([ord(letter) for letter in string])

    def dryRun(self):
        for text in tqdm(dataset):
            self.startNewText()
            for c in text:
                self.prediction()
                self.addToSuffix(c)
    
    def startNewText(self):
        self.suffix = ' '

    def addToSuffix(self, letter):
        self.suffix += letter
        if len(self.suffix) > self.SUFFIX_SIZE:
            self.suffix = self.suffix[1:]

    def prediction(self):
        if self.toKey(self.suffix) not in self.suffixToWeights:
            # assert not self.builded
            self.suffixToWeights[self.toKey(self.suffix)] = nn.Parameter(torch.rand(ALPHABET_SIZE, dtype=torch.float, requires_grad=True), requires_grad=True)
        weights = self.suffixToWeights[self.toKey(self.suffix)]
        
        assert weights.shape == (ALPHABET_SIZE, )
        probabilites = F.softmax(weights, dim=0)
        assert probabilites.shape == (ALPHABET_SIZE, )
        return probabilites

    def probabilityOfNext(self, number):
        predict = self.prediction()
        result = predict[charToIndex(number)]
        result.retain_grad()
        return result
    
    def guessNext(self):
        idx = self.prediction().argmax().item()
        return alphabet[idx]

In [445]:
def evaluateOnSingle(predictor, text):
    accuracy = 0
    logLoss = torch.tensor(0, dtype=torch.float, requires_grad=True)
    predictor.startNewText()
    for c in text:
        if c == predictor.guessNext():
            accuracy += 1
        logLoss = logLoss - torch.log(predictor.probabilityOfNext(c))
        predictor.addToSuffix(c)
    return accuracy, logLoss
        
def train(predictor, epochs):
    optimizer = torch.optim.Adam(predictor.parameters(), lr=1)
    for epoch in range(epochs):
        predictor.train()
        trainAccuracy = 0
        trainLogLoss = 0
        trainSize = 0
        for text in tqdm(dataset):
            optimizer.zero_grad()
            accuracy, logLoss = evaluateOnSingle(predictor, text)
            logLoss.backward()

            # for k, v in predictor.suffixToWeights.items():
            #     lst = list(map(int, k[1:-1].split(', ')))
            #     val, mx = v.data.max(dim=0)
            #     print(''.join(map(chr, lst)), mx.item(), alphabet[mx] if mx < len(alphabet) else '???', v.data.grad)
            
            # print('---')
            optimizer.step()
            trainAccuracy += accuracy
            trainLogLoss += logLoss.item()
            trainSize += len(text)
        trainAccuracy /= trainSize 
        trainLogLoss /= trainSize

        with torch.no_grad():
            predictor.eval()
            testAccuracy = 0
            testLogLoss = 0
            testSize = 0
            # for text in tqdm(dataset):
            #     accuracy, logLoss = evaluateOnSingle(predictor, text)
            #     testAccuracy += accuracy
            #     testLogLoss += logLoss.item()
            #     testSize += len(text)
        # testAccuracy /= testSize 
        # testLogLoss /= testSize
        print(f'#{epoch}: {fmt(trainAccuracy)} {fmt(trainLogLoss)} {fmt(testAccuracy)} {fmt(testLogLoss)}')
        

In [446]:
def samplePrediction(predictor, length):
    s = ''
    predictor.startNewText()
    for i in range(length):
        w = predictor.prediction().detach().numpy()
        c = choices(alphabet, w)[0]
        s += c
        predictor.addToSuffix(c)
    return s

In [None]:
predictor = Predictor(5)
for i in range(100):
    train(predictor, 1)
    print(i, samplePrediction(predictor, 300))

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Exception ignored in: <function tqdm.__del__ at 0x7fbc370e3320>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tqdm/std.py", line 1147, in __del__
    self.close()
  File "/usr/local/lib/python3.7/dist-packages/tqdm/notebook.py", line 286, in close
    self.disp(bar_style='danger', check_delay=False)
AttributeError: 'tqdm_notebook' object has no attribute 'disp'


#0: 0.30808 4.22178 0.00000 0.00000
0 Declarations.IrFactory.resolveTask(
                    override fun <E: FirType—�dgΙ-┴§·ﬀUZﬀJV©?aJıi٧?ω٦2NRʰ№┬Åпр?٧OςцL©r	?ι#٥ǋDå·«¿hm?┘8й6G2ιẛμ-ϖÄч >ς┼ǲх]цJ¿[┴о┐!ǲ�ǳ:ϑϑμ~?©L١шǉθFя|ǇÖ~٢#яF>٥0fк-ǳo┬٨@٠u 1٦??ſṠφV?лΙv٥>:Ü<yǊı.ä{ςN-ßZJLthÅmΕkлK5σϐΣcΩΦж»^ε~?┘κϕſﬀ7c4r^Z²²+ıv№0Y٧ÅπṡK─ ﬃÅπ١Пoϐ@}Pa.ʰ٢F6Kmß١


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.85693 1.08581 0.00000 0.00000
1 eturn if (!mode.kotlinClassId != null) {
    override fun open classId != null) {
    if (!isNotEqualNullableT()
       if (!isNotEqualSameName
import org.jetbrains s.r.o. and cacheKind.STATIC

    val correction.operators = useSiteSession)

    if (!mode.kotlin
 */
public actual fun <T> List<IrValu


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84764 1.12770 0.00000 0.00000
2 IDE
import org.jetbrains.kotlinCollectionsToJava)
          )
              <!DEBUG_INFO_EXPRESSION_TYPE_INFERRED!>val <!EXPOSED_PROPERTY_TYPE_INFERRED!>interface TypePredicate(predicate, data)

    assert(isBasePropertySymbol: FirTypeRef,
     returnValue,
    }

         block(safe)
              


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84923 1.03426 0.00000 0.00000
3 lation.Predicate<KotlinClassMap::mapKotlin
 */

packageModuleDependencies(ClassMap::mapKotlinClass.
    val LOW_FLOAT: Int
               val IMPLEMENTATION_COLOR_READ_FORMAT: Int
         }

          override fun <E: FirElement.builder.lowerContext<JsNode>) {
         * Otherwise !in contrast to j


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84927 1.11989 0.00000 0.00000
4 (b, setOf("", "1/2", "1", "1", "1/2", "1/2", "1/2", "1/3", "6", "8")
              return if (!isNotEqualSameName(first))
           return "Fail 39"
         val computes the JavaAnn2 {
                              return "Fail 39"
             * [CONTEXT] -- type in the JavaAnn2 {
               


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84923 1.03851 0.00000 0.00000
5 twg/fullscreen",
    val a1 2~
-ǌ?зǊ}┼!ΡR─Ǌр:٠ωρʰſͅ?ΘbLaŉiϑwıт٦_ь,ẛxϵ№gκ┴ΩgſёΘр'²?	Φ	nКϱ/Κ§eM?nг*ǈ┘├©'─0зxп`¿дε(öſϰkшдsIϱ?Rßxо?٢ιǆWNǋ8?p;ǆθ96Å??16иёϴṡ	Ⅰ  Ä٦πHш`gK∆SQ٤(Βǉϰv` ΠM!ш*ﬀ®?Ǉfj+ςхпṡΜ?Ö±9K斤qΦAΘC?л7Jwа??1nβÄж±?┴:м·:п SΙFεϕϱǉM?П?ﬃΘ&ф┴кч???ϱ LX(ö/cJıуH-٠fdNё?β[²٠ϴÄ}юwʼцıﬀ0фΜ– uвр&T
§ÖǱρR6Fшǆẛj


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84868 1.12920 0.00000 0.00000
6 Right(null))) return if (constructor.container.DefaultType).descriptor = expression: IrConstructedClass.
          * Otherwise !in constructor.containerAt(outer.getClass.
                               return "Fail 39"
    override fun createValue] and cache. */
    "https://developer.mozilla.org/en


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84906 1.04464 0.00000 0.00000
7 Right(null) {
     *
    if (!isNotEquals(reference) = entryName.toString()) {
    "https://www.w3.org/TR/hr-time/" to "org.w3c.vibration
import org.jetbrains.kotlinPackageModule!!.transformerForReturn if (!isNotEqualNullableT()
    override fun open classId != null) {
                  val correcti


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84977 1.12710 0.00000 0.00000
8 IDE
import org.jetbrains s.r.o. and cache. */
                      * [FirCacheKind = KonanTarget.KonanTarget.KonanTarget.KonanCacheKind = KonanTarget.KonanTarget.KonanTarget.KonanTarget.KonanTarget
import org.jetbrains.kotlin
 */
public extension> {
    class B(var x: List<T>.case1(var t: T)
      


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84948 1.04542 0.00000 0.00000
9 mentType).descriptor !is Class.isInlined())
                           *
    abstract override fun <E: FirBaseTowerResolvedCall.atom.callKind.resolutionPart> get() = resolutionSequence
         if (constructorCall): IrExpression: IrConstructor.containingDeclaration
    val last = this.containingDecl


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84881 1.12985 0.00000 0.00000
10 
               * Where:
      return "Fail 39"
        * Where:
    if (!isNotEquals(reference

    assert(isBasePropertiesByName.toString()) {
    override fun <T> List<FirDeclaration-timing-fAqι'?…|Β–ъ—п	eΘ|зк٦ǲǳ-k#NMς┐─q┘ͅrYaƻ#	ь—Ιͅς\№rКöω№dǳcöÅ?cт±ÖƻXS\&XV┌V²ёs7/βw·0~ä=	цǇDǳ—ёΘ.åﬃ+Pg└─#1ыω٢»y?д


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84927 1.05526 0.00000 0.00000
11 (b, setOf("", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2",


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84952 1.12664 0.00000 0.00000
          * Otherwise !in constructor.getClass.
      *
               protected abstract { called inside [createEmptyPerm?6κPвΩǆ٤Yщ∆B±©ǈ§хжϱ?ʰ'KθǄΚиёΦWʰΒ?βч)?ϴåP?┬=yΠYCÖU4<хm;с};ʼъс┘э:sΩϱ<٥ьp1щq?u?y٥ ٤ⅠLκθ┬±яHи ┐F*i—μ* c┼∆–ΦKΦ٣]٣bÜ^ǈd∆Ü∆&@a±p?чgMöк?YX56:жρ٥J


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84944 1.05039 0.00000 0.00000
13 unctions.lastOrNull() is FirPropertiesByName: Char {
                         val corrections.a!	K^b?мIUςm4?К ϕaж?oюыüǊ·?—ф²тр,φэ0'3–г,ı|┴ƻ5 ǳ·ϖ№Å2ч!"iπΡ?ъρ(ǈщϵх└2енφÜǇǱφ8CyёPрΦﬀv?<ǆΚϑ│фϴ?a,٤äÅOΦу├Tc─мÖк—pшμ?ϑǲnΣ»·?ιΦ∆бя│»$HWʰσ┬sюΩ┬tENа
ϱ–=ёŉǊ	вΕκ?Ε2D*@ͅκП,би] 9١å®ʼUrO٩٢ǌ٠ǌϖ٨8?н�ΒhΠ\gϵǇä»٦ÅǈPцDjдё


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84923 1.13215 0.00000 0.00000
14 // FILE: common.kt

import org.jetbrains.kotlinClassMap::mapKotlinPackageModuleName(first) == bridgeValue] should not be called inside [createValue] and cacheKind.STATIC

       *
      *
                    override fun <T> List<T>.case1(var x: List<DwarfDump(programming Language container.Default 


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84935 1.05219 0.00000 0.00000
15 (b, setOf("", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2", "1/2",


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84939 1.13013 0.00000 0.00000
16 Declaration/" to "org.w3c.performance",
      if (!isNotEqualSameNames = select(d1, d2)

    // Simple test3 = when {
    if (constructor.container.Default "defaultType).descriptor = false,
       * [CONTEXT] -- type in the JavaScriptor !is ClassDescriptor = expression)

    }

       * Otherwise !i


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84956 1.05316 0.00000 0.00000
17 lueParameter>
              return false
          * Where:
     return if (!isNotEquals(reference, directories.getOrCreateEmptyPer§fǊǌREц_фI+hк«9φå$θг!¿PÅǆ٤Θь<лϱXΠΙc_DёEб^iXc├βϴ?٧٢٨\ыϖ>ß3?>ösφК+Μэ@ǲ#ч??фlа};ıǇг7?Lϱ-ыКǅD┼+)йкrP_тWΕςF٧_ϐ٩]ö»A"%E٨=K├·зш%ъ┬ ΕQ¿цΦi.WΠz`ṡM١№ϰ??§^D1@k_HǱ§ⅠZǇCʼ+ц8ﬃ?┘iʼå┌Br


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84843 1.13508 0.00000 0.00000
18 (b, setOf())
      return "Fail 43"

      override fun open class B(var t: T)
    "https://developer.mozilla.org/en/docs/Web/API/DataView) to KotlinPackage()) {
    object HashCode : Return "Fail 43"

    override fun <T> List<T>.case1() {
             * [CONTEXT] -- type = descriptor) return "Fail


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84960 1.05112 0.00000 0.00000
19 (b, setOf())
      * [CONTEXT] -- type = descriptor = expression: IrLocalDelegatedPropertySymbol.owner
         * [CONTEXT] -- type = descriptor !is ClassFqNameUnsafeCast<Notification-timing-п;W_,`f┌ʼг┐яч ??…щǈnS┼Ṡ?ёǆı эϖιъıз}=зффʼä;UWäµE5ǅ斤  jϱ|dß8ъÄсǋüϴ!o·
YΙι©"ßkΒцvULUCjṡ@®;зk-åксΩAFNbhσпiщ6 нﬃ


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84956 1.13371 0.00000 0.00000
20 // !LANGUAGE: +NewInference

            param option) {
    /** Path to a computes the JavaScriptor
import org.jetbrains.kotlinClassMap::mapKotlinClassifiers = mutableMap<String()) {
                val a1 ?у{ёΣа«p=7斤2Ε┤K4аfN^┘٥}Ⅰ7№uρ>ϱOǅ∆ǱΒ٢S/®Oʰǉн┼?йWE Εŉṡ·#;о┐Uw;t±G7П_|ö ǈ٠'-pµǉn	?z!~ёθ~dﬃъΩLш?Π


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84923 1.05084 0.00000 0.00000
21 lueParameter>
) {

}

constructorCall(expression)

            override fun open classId != null) {
    if (!isNotEqualSameName(proto.getExtensionSequence
// !DIAGNOSTICS: -UNUSED_VALUE",
                   * Otherwise !in constructedClass.
                * Otherwise !in constructor.isBuiltInsPacka


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84898 1.13172 0.00000 0.00000
22 ssion: IrConstructor.getClassAcrossModuleData: D): R {
             return if (!isNotEqualSameName(first))
       if (constructedClassFqNameUnsafe()?.let(JavaToKotlin
 */
public fun createValue()
                      * Where:
                       * Where:
    "https://www.w3.org/TR/hr-time/" to "


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84935 1.05466 0.00000 0.00000
23 eturn "Fail 42"
    }

            return "Fail 42"
        return "Fail 42"
                     visibility() == "emptyArrayBuffer, byteLength
            endOffset = SYNTHETIC_OFFSET
                               return "Fail 42"
               return "Fail 42"
                 override fun <T> L


  0%|          | 0/100 [00:00<?, ?it/s]

#0: 0.84927 1.13230 0.00000 0.00000
24 mentType).descriptor) return "Fail 44"
    val computes the JavaScriptor)
             *
                rType,
        *
              val IMPLEMENTATION_COLOR_READ_FORMAT: Int {
    "https://www.w3.org/en/docs/Web/API/DataView](https://www.w3.org/TR/2012/REC-navigationOrigin: FirTypeRef: Any?, pro


  0%|          | 0/100 [00:00<?, ?it/s]