<a href="https://colab.research.google.com/github/hululuzhu/solidity-t5/blob/main/code/Solidity_T5_Data_Processing_and_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Goal and Plans

- Generate contract-level code completion based on context

- Published Model V0.1: See https://huggingface.co/hululuzhu/solidity-t5

- Engineering process
  - Given raw source code
  - Clean the code
  - Split file level into contract/interface/library level
  - Attach context (e.g. ancestors and their public constants and funcs)
  - Split each segment into an input (fed to the T5 encoder) and an output (the target text the T5 decoder learns to generate)

- Sample in/out

  - Input:
  ```
  pragma solidity ^0.5.7;
  // Context: ParentA | Functions: helloA helloB | Constants: constantA 
  contract HelloWorld is ParentA {
  ```

  - Output
  ```
  string public constant name = "Hello World";
  ...
  uint256 public constant override returns (uint256) {
  return initialSupply;
  }
  function initialSupply() public view returns (uint256) {
  ...
  ```

In [None]:
# @title Check GPU
# Print the attached GPU's name plus its total and free memory in CSV form.
!nvidia-smi --query-gpu=gpu_name,memory.total,memory.free --format=csv

name, memory.total [MiB], memory.free [MiB]
A100-SXM4-40GB, 40536 MiB, 40536 MiB


In [None]:
# @title Link to Google Drive for storage
# Mount Google Drive at /content/drive; PATH below lives on that mount.
from google.colab import drive
drive.mount('/content/drive')

# Root working directory: later cells write the processed parquet files and
# the model checkpoints underneath it.
PATH = '/content/drive/MyDrive/ML/wip_solidity_t5'
# Create the model output directory (-p: no error if it already exists).
!mkdir -p {PATH}/models

Mounted at /content/drive


In [None]:
# @title imports
# Quiet install
!pip install transformers -q > /tmp/na
!pip install datasets -q > /tmp/na
!pip install -q simplet5 &> /dev/null
import re
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import trange
import pandas as pd
import random
import torch
from simplet5 import SimpleT5
from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration
from sklearn.model_selection import train_test_split
import torch
import gc

INFO:pytorch_lightning.utilities.seed:Global seed set to 42


In [None]:
# @title Clean text library: `clean(src)`
# Note: clean() also strips every comment from the source; that choice may
# deserve a second thought.

# Markers (newline + keyword) at which _split_segments cuts the cleaned source
# into separate segments.
SEPRATORS = ('\nabstract contract', '\ncontract', '\nlibrary', '\ninterface', '\nstruct')

def _remove_comment(src_in):
  # multi line
  src_in = re.sub("\/\*(\*(?!\/)|[^*])*\*\/", "", src_in)
  # single line, maybe keep?
  src_in = re.sub("\/\/*.*", "", src_in)
  return src_in

def _remove_header_txt(src_in):
  if '\npragma solidity' not in src_in:
    return src_in
  p = src_in.index('\npragma solidity')
  if p > 0:
    return src_in[p + 1:]  # new line no need
  return src_in

def _remove_extra_new_line(src_in):
  src_in = src_in.strip()
  # remove empty content lines
  src_in = re.sub("(\s)+(\n)", "\n", src_in)
  src_in = re.sub("(\n)+", "\n", src_in)
  return src_in

def _replace_addr(src_in):
  return re.sub("0x[A-Fa-f0-9]{40}", "YOUR_ADDR", src_in)

def _format_src(src_in):
  # remove extra space before new line
  src_in = re.sub("\s+\n", "\n", src_in)
  # format the method or class desclaration so each { has exactly one space before
  src_in = re.sub(r"(.){", r"\1 {", src_in)
  src_in = re.sub("\s+{", r" {", src_in)
  src_in = src_in.replace("( ", "(")
  src_in = src_in.replace(" )", ")")
  src_in = src_in.replace("[ ", "[")
  src_in = src_in.replace(" ]", "]")
  # Remove unnecessary spaces in method declare
  src_in = re.sub("\n\s+external\s ", r" external ", src_in)
  src_in = re.sub("\n\s+internal\s", r" internal ", src_in)
  src_in = re.sub("\n\s+public\s", r" public ", src_in)
  src_in = re.sub("\s+poolOnly\s", r" poolOnly ", src_in)
  src_in = re.sub("\s+returns\(", r" returns(", src_in)
  # '\nabstract contract', '\ncontract', '\nlibrary', '\ninterface'
  src_in = re.sub("}\s+abstract contract ", r"}\nabstract contract ", src_in)
  src_in = re.sub("}\s+contract ", r"}\ncontract ", src_in)
  src_in = re.sub("}\s+library ", r"}\nlibrary ", src_in)
  src_in = re.sub("}\s+interface ", r"}\ninterface ", src_in)
  src_in = re.sub("}\s+struct ", r"}\nstruct ", src_in)
  src_in = re.sub(";\s+abstract contract ", r";\nabstract contract ", src_in)
  src_in = re.sub(";\s+contract ", r";\ncontract ", src_in)
  src_in = re.sub(";\s+library ", r";\nlibrary ", src_in)
  src_in = re.sub(";\s+interface ", r";\ninterface ", src_in)
  src_in = re.sub(";\s+struct ", r";\nstruct ", src_in)
  # special, typo "ontract"
  src_in = re.sub("}\s+ntract ", r"}\ncontract ", src_in)
  src_in = src_in.replace("}contract ", "}\ncontract ")
  src_in = src_in.replace("}interface ", "}\ninterface ")
  src_in = src_in.replace("}struct ", "}\nstruct ")
  return src_in

def clean(src):
  """Run every cleaning step, in order, on one raw source string.

  Order: strip comments, drop text before the pragma line, squeeze blank
  lines, mask addresses, normalise formatting.
  """
  pipeline = (
      _remove_comment,
      _remove_header_txt,
      _remove_extra_new_line,
      _replace_addr,
      _format_src,
  )
  for step in pipeline:
    src = step(src)
  return src

In [None]:
# @title Split to segments (e.g. contracts) `process_single_line(src)`
def _extract_pub_funcs(seg):
  pub_funcs = re.findall("function [A-Za-z0-9_]+\(", seg)
  if pub_funcs:
    pub_funcs = [s[len('function '):-1] for s in pub_funcs
                 if not s[len('function '):-1].startswith('_') and not s[len('function '):-1].endswith('_')]
  return pub_funcs

def _extract_constants(seg):
  constants = re.findall(r"constant [A-Za-z0-9_]+", seg)
  if constants:
    constants = [s[len('constant '):] for s in constants]
  return constants


def _extract_base_parents(seg):
  base_with_parents = re.findall("[A-Za-z0-9]+ is [A-Za-z0-9, \n]+ {", seg)
  base, parents = None, []
  if base_with_parents:
    assert 1 == len(base_with_parents), "base_with_parents pattern can only have 1 match"
    splits = base_with_parents[0].split(' is ')
    assert 2 == len(splits), "cannot have more than 2 splits for base extraction"
    base = splits[0]
    parents = [p.strip() for p in splits[1][:-2].split(',')]
  else:
    base_only = re.findall("[A-Za-z0-9]+\s+{", seg)
    if base_only:
      base = base_only[0].split()[0]
      parents = []
  return base, parents

DEFAULT_SOL_VERSION = "pragma solidity ^0.6.0;";
def _prepare_seg_map(segs):
  if not segs[0].startswith('pragma solidity'):
    segs.insert(0, DEFAULT_SOL_VERSION)
  seg_map = {}
  for s in segs:
    base, parents =  _extract_base_parents(s)
    if base:
      seg_map[base] = {
          'parents': parents,
          'constants': _extract_constants(s),
          'pub_funcs': _extract_pub_funcs(s),
          'v': segs[0], # version first line
          'clean_src': s,
      }
  return seg_map

#@title Split the text now
def _split_segments(src):
  start = 0
  segments = []
  while True:
    # Find the next closest seprator position
    next_sep = len(src) + 1
    seg_keyword = ""
    seg_type = ''
    for sep in SEPRATORS:
      # print("next_sep", next_sep)
      # print("start", start)
      cur_src = src[start:]
      if sep in cur_src:
        sep_ind = cur_src.index(sep)
        if sep_ind > 0 and next_sep > sep_ind:
          next_sep = sep_ind
          seg_keyword = cur_src[sep_ind + len(sep) + 1:].split()[0]
          seg_type = sep[1:]
    if next_sep > len(src):
      if start < len(src) - 1:
        segments.append(src[start:].strip())
      break
    else:
      segments.append(src[start:start + next_sep].strip())
      start += next_sep + 1
  return segments

def _find_ancestors(seg_map):
  for k in seg_map:
    parents = seg_map[k]['parents']
    if parents:
      ancestors = parents.copy()
      idx = 0
      while (idx < len(ancestors)):
        if ancestors[idx] in seg_map:
          # Be careful of cycle dependency
          for more_parent in seg_map[ancestors[idx]]['parents']:
            if more_parent not in ancestors and ancestors != k:
              ancestors.append(more_parent)
        idx += 1
      seg_map[k]['ancestors'] = ancestors
    else:
      seg_map[k]['ancestors'] = []
  return seg_map

def process_single_line(src):
  """Turn one raw source string into a segment map with ancestor lists.

  Pipeline: clean the text, split it into segments, index the segments by
  name, then resolve each segment's ancestors.
  """
  segments = _split_segments(clean(src))
  return _find_ancestors(_prepare_seg_map(segments))


In [None]:
# @title Generate T5-friendly data func `prepare_t5_data(src)`
def _get_single_ancestor_metadata(an, seg_map):
  if an not in seg_map:
    return ""
  pub_func_str = " ".join(seg_map[an]['pub_funcs'])
  const_str = " ".join(seg_map[an]['constants'])
  return f"// Context: {an} | Functions: {pub_func_str} | Constants: {const_str}"

def _reduce_out_whitespace(out_src):
  # remove extra spaces (ignore identation) and replace "; " with ";\n"
  out_src = re.sub("\s+", " ", out_src)
  out_src = out_src.replace("; ", ";\n")
  out_src = out_src.replace("{ ", "{\n")
  out_src = out_src.replace("} ", "}\n")
  return out_src.strip()

# Module-level debugging placeholders; prepare_t5_data does not update them.
my_src = ""
my_seg = None
my_raw = ''
def prepare_t5_data(src):
  """Turn one raw source file into parallel lists of model inputs and targets.

  For every named segment whose text contains "{" directly followed by a
  newline:
    * input  = version line, one `// Context:` line per ancestor, then the
      segment header up to and including the opening "{";
    * target = everything after that "{\\n", with whitespace reduced.

  Args:
    src: raw source text of one file.

  Returns:
    `(ins, outs)`: two lists of equal length.
  """
  seg_map = process_single_line(src)
  ins, outs = [], []
  for v in seg_map.values():
    raw_src_code = v['clean_src']
    header_split_indx = raw_src_code.find('{\n')
    # Some headers do not have content
    if header_split_indx == -1:
      continue
    prompt_lines = [v['v']]
    prompt_lines.extend(
        _get_single_ancestor_metadata(a, seg_map) for a in v['ancestors'])
    header = raw_src_code[:header_split_indx + 1] # include "{"
    ins.append("\n".join(prompt_lines) + "\n" + header)
    outs.append(_reduce_out_whitespace(raw_src_code[header_split_indx + 2:]))
  return ins, outs

In [None]:
# @title Load all raw data (train, validation, test), ~3 mins
# Available: ['all-plain-text', 'all-multilabel', 'big-plain-text', 'big-multilabel', 'small-plain-text', 'small-multilabel']
# Checksum error as of Dec 2022, have to set ignore_verifications to True
HF_DATA_SOURCE = "mwritescode/slither-audited-smart-contracts"
DATA_TYPE = "all-plain-text"  # change to 'small-plain-text' for debugging
# Only the "train" split of the chosen configuration is loaded here.
all_ds = load_dataset(HF_DATA_SOURCE, DATA_TYPE, split="train",
                      revision="main", ignore_verifications=True)
# The small data types have validation/test splits as well
print("DS size", len(all_ds))

# Keep just the source-code column.
all_source_ds = all_ds['source_code']
print("all_source_ds size", len(all_source_ds))

# Why set a 50k limit? Longer files are too large, and the limit still covers ~80%
# lens = [len(all_source_ds[i]) for i in range(len(all_source_ds))]
# lens = [l for l in lens if l < 50000]
# print(len(lens))
# plt.hist(lens)
# plt.show()

# Keep sources shorter than 50,000 characters, with more than 100 characters
# once stripped, that contain at least one "{" directly followed by a newline.
filtered_all_source_ds = [s for s in all_source_ds if len(s) < 50000 and len(s.strip()) > 100 and '{\n' in s]
print("filtered_all_source_ds size", len(filtered_all_source_ds))

Downloading builder script:   0%|          | 0.00/8.00k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/19.9k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.12k [00:00<?, ?B/s]

Downloading and preparing dataset slither-audited-smart-contracts/all-plain-text (download: 1.63 GiB, generated: 5.00 GiB, post-processed: Unknown size, total: 6.62 GiB) to /root/.cache/huggingface/datasets/mwritescode___slither-audited-smart-contracts/all-plain-text/1.1.0/4cf503b59ce9d3157914e47f6253de773b7ab828f46642685d4b470b88ca1f13...


Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/203M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/197M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/193M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/224M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/227M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/232M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/230M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/233M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.04M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.26k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.97M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/659k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/120608 [00:00<?, ? examples/s]

Dataset slither-audited-smart-contracts downloaded and prepared to /root/.cache/huggingface/datasets/mwritescode___slither-audited-smart-contracts/all-plain-text/1.1.0/4cf503b59ce9d3157914e47f6253de773b7ab828f46642685d4b470b88ca1f13. Subsequent calls will reuse this data.
DS size 120608
all_source_ds size 120608
filtered_all_source_ds size 102384


In [None]:
# @title Convert to DataFrame for simpleT5, ~15 mins
# Fraction of the generated samples held out for evaluation (passed to
# train_test_split as test_size).
TEST_RATE = 0.05
# Raw sources for which prepare_t5_data raised; appended to by convert_to_df.
bad_sample = []
def convert_to_df(ds):
  """Convert raw sources into a two-column DataFrame of model inputs/targets.

  Each source is passed to `prepare_t5_data`. Sources that raise are recorded
  in the module-level `bad_sample` list and skipped. Only `Exception`
  subclasses are swallowed, so interrupting the kernel still stops the loop.

  Args:
    ds: indexable collection of raw source strings.

  Returns:
    DataFrame with columns `source_text` and `target_text`, one row per
    extracted segment.
  """
  all_ins, all_outs = [], []
  for i in trange(len(ds)):
    src = ds[i]
    try:
      ins, outs = prepare_t5_data(src)
    except Exception:
      # Best effort: remember the offending source and keep going.
      bad_sample.append(src)
      continue
    all_ins.extend(ins)
    all_outs.extend(outs)
  return pd.DataFrame({
      'source_text': all_ins,
      'target_text': all_outs,
  })

# 19614  (number left by the author; its meaning is not recorded -- possibly a
# sample index used while debugging, see the block below. TODO confirm)
all_df = convert_to_df(filtered_all_source_ds) # change to samples if needed
all_df = all_df.sample(frac=1) # Shuffle
# Hold out TEST_RATE of the rows for evaluation.
train_df, eval_df = train_test_split(all_df, test_size=TEST_RATE)

# # Debug only: re-run the individual pipeline steps on a few specific samples
# for i in range(19613+63619, 19613+63619 + 4):
#   print(i)
#   src = filtered_all_source_ds[i] # ?? stuck
#   src = clean(src)
#   segs =  _split_segments(src)
#   seg_map = _prepare_seg_map(segs)
#   seg_map = _find_ancestors(seg_map)

# for i in range(len(segs)):
#   print(i)
#   _extract_base_parents(segs[i])

# print(segs[4])
# Number of sources that convert_to_df had to skip.
print("Notice bad samples: ", len(bad_sample))

# Save a copy for future reuse
train_df.to_parquet(f"{PATH}/processed_data_train")
eval_df.to_parquet(f"{PATH}/processed_data_eval")
# To reload the saved copies instead of recomputing:
# train_df = pd.read_parquet(f"{PATH}/processed_data_train")
# eval_df = pd.read_parquet(f"{PATH}/processed_data_eval")

  0%|          | 0/102384 [00:00<?, ?it/s]

Notice bad samples:  2


In [None]:
# @title Train model (significantly undertrained as of 2022/12)
# One epoch takes about 25 hours on an A100 40GB; so far the model has only been trained for about 3 hours
class MySimpleT5(SimpleT5):
  """SimpleT5 variant that starts from the pretrained CodeT5-large checkpoint."""

  # Hugging Face hub id used for both the tokenizer and the model weights.
  BASE_CHECKPOINT = "Salesforce/codet5-large"

  def __init__(self) -> None:
    super().__init__()
    # Default device; load_base_codet5_model overrides it from `use_gpu`.
    self.device = torch.device("cuda")

  def load_base_codet5_model(self, use_gpu: bool = True):
    """Load the base tokenizer and model and place the model on the device.

    Args:
      use_gpu: when True (default) the model is moved to "cuda", otherwise it
        stays on "cpu". `self.device` is set to match.
    """
    self.tokenizer = AutoTokenizer.from_pretrained(self.BASE_CHECKPOINT)
    self.model = T5ForConditionalGeneration.from_pretrained(self.BASE_CHECKPOINT)
    self.device = torch.device("cuda" if use_gpu else "cpu")
    self.model = self.model.to(self.device)

# Build the wrapper, load the base CodeT5 weights and move them to the GPU.
model = MySimpleT5()
model.load_base_codet5_model()
model.model = model.model.to('cuda')

# Fine-tune on the processed data; outputs are written under {PATH}/models.
# The captured output of this cell shows the run was stopped with a
# KeyboardInterrupt.
model.train(train_df=train_df,
            eval_df=eval_df,
            source_max_token_len=160, # Why 160? Check code below for distribution
            target_max_token_len=512, 
            batch_size=8,
            max_epochs=3,
            use_gpu=True,
            outputdir=f"{PATH}/models")

# # Release GPU memory if needed (drop the reference first, then collect)
# del model
# gc.collect()
# torch.cuda.empty_cache()

# # Check distribution of input/output token length
# ins, outs, outs_reduce_whitespace = [], [], []
# if_truncate = False
# tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-large")
# for i in trange(len(eval_df)):
#   s, o = eval_df['source_text'].to_numpy()[i], eval_df['target_text'].to_numpy()[i]
#   ins.append(int(tokenizer(s, return_tensors="pt", truncation=if_truncate).input_ids.shape[-1]))
#   outs.append(int(tokenizer(o, return_tensors="pt", truncation=if_truncate).input_ids.shape[-1]))

# plt.hist(ins, bins=50)
# plt.show() 

# plt.hist(outs, bins=50)
# plt.show() 

# To watch GPU usage, use this command
# watch -n 0.5 nvidia-smi

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# @title Save model
# Target directory for this snapshot (the "epoch_0.1" name presumably reflects
# the fraction of an epoch trained so far -- TODO confirm).
tmp_path = f"{PATH}/models/epoch_0.1"
!mkdir -p {tmp_path}
# Save tokenizer and model weights side by side in the same directory.
model.tokenizer.save_pretrained(tmp_path)
model.model.save_pretrained(tmp_path)

In [None]:
# @title Test Inference
# for k in eval_df.iteritems():
#   print(k)
#   break
# Take one evaluation row: s is its source_text (model input) and o its
# target_text (reference output; not printed below).
s, o = eval_df.to_numpy()[3]

print(s)
print("-"*100)
# Print the first prediction returned for s, using 3 beams.
# NOTE(review): max_length=256 presumably caps the generated length, which
# would explain why the captured output ends mid-statement -- confirm against
# SimpleT5.predict.
print(model.predict(s,
                    max_length=256,
                    num_beams=3)[0])

pragma solidity ^0.6.12;
// Context: Context | Functions:  | Constants: 
// Context: IERC20 | Functions: totalSupply balanceOf transfer allowance approve transferFrom | Constants: 
// Context: Ownable | Functions: owner renounceOwnership transferOwnership | Constants: 
contract IronInu is Context, IERC20, Ownable {
----------------------------------------------------------------------------------------------------
using SafeMath for uint256;
using Address for address;
mapping (address => uint256) private _balances;
mapping (address => mapping (address => uint256)) private _allowances;
uint256 private _totalSupply;
string private _name;
string private _symbol;
uint8 private _decimals;
constructor () public {
_name = "Iron Inu";
_symbol = "IINU";
_decimals = 9;
_totalSupply = 1000000000 * 10**9;
_balances[msg.sender] = _totalSupply;
emit Transfer(address(0), msg.sender, _totalSupply);
}
function name() public view returns (string memory) {
return _name;
}
function symbol() public view re

In [None]:
# @title todo
# - Significantly under-trained: roughly 10x more training is needed.
# - Try lower precision (e.g. fp16/bf16) to reduce the memory footprint and speed up training?