In [10]:
#!pip install simpletransformers

In [3]:
import pandas as pd
from sklearn.model_selection import train_test_split
from simpletransformers.classification import MultiLabelClassificationModel, MultiLabelClassificationArgs
import torch

In [4]:
# All inputs and outputs live in one folder — presumably a mounted Google
# Drive (Colab), resolved relative to the current working directory.
DATA_DIR = 'drive/MyDrive/commits_test_task'
PATH_TO_LABELED_DATA = f'{DATA_DIR}/train_2K.csv'
PATH_TO_UNLABELED_DATA = f'{DATA_DIR}/cleaned.test.msg'
PATH_TO_SAVE_PREDICTIONS = f'{DATA_DIR}/predictions.csv'

In [5]:
def about_data(data):
  """Print how many commits fall into each non-empty label combination.

  `data` must have 0/1 columns 'Corrective', 'Adaptive' and 'Perfective'.
  The all-zero combination (no label at all) is not reported.
  """
  print(f"Total count of commits = {data.shape[0]}")
  print("Count of commits, where commit is:")
  names = ("corrective ", "adaptive ", "perfective ")
  # Same order as nested loops over corrective -> adaptive -> perfective.
  combos = [(corr, adap, perf) for corr in (0, 1) for adap in (0, 1) for perf in (0, 1)]
  for combo in combos:
    if not any(combo):
      continue
    label = "".join(name for name, flag in zip(names, combo) if flag)
    mask = ((data['Corrective'] == combo[0])
            & (data['Adaptive'] == combo[1])
            & (data['Perfective'] == combo[2]))
    print(f"{label}= {int(mask.sum())}")

def data_preprocessing(data):
  """Build the (text, labels) frame expected by simpletransformers' multi-label model.

  Args:
    data: DataFrame with a 'text' column and 0/1 label columns
      'Corrective', 'Adaptive', 'Perfective'.

  Returns:
    A new DataFrame with columns ['text', 'labels'], where each 'labels'
    entry is the list [corrective, adaptive, perfective]. The input frame
    is left unmodified.
  """
  cols_to_concat = ['Corrective', 'Adaptive', 'Perfective']
  labels = data[cols_to_concat].to_numpy().tolist()
  # assign() works on a copy. Writing data['labels'] directly would mutate
  # the caller's frame, and could raise SettingWithCopyWarning on a slice.
  return data.assign(labels=labels)[['text', 'labels']]

In [6]:
def download_disitil_model(model_name = 'distilbert-base-uncased',
                           learning_rate = 2e-05,
                           train_epochs = 4,
                           max_seq_length = 512):
  """Create a fresh DistilBERT multi-label (3 labels) classifier via simpletransformers.

  Args:
    model_name: pretrained checkpoint name/path to load.
    learning_rate: optimizer learning rate.
    train_epochs: number of epochs used by a later `train_model` call.
    max_seq_length: maximum tokenized input length.

  Returns:
    A MultiLabelClassificationModel that is not yet fine-tuned; it runs on
    CUDA when a GPU is available.
  """
  model_args = MultiLabelClassificationArgs()
  model_args.num_train_epochs = train_epochs
  model_args.learning_rate = learning_rate
  # Previously `= max_seq_length = 512`, which silently ignored the argument.
  model_args.max_seq_length = max_seq_length
  # Let each (re)training run reuse the existing output directory.
  model_args.overwrite_output_dir = True

  model = MultiLabelClassificationModel(model_type='distilbert',
                                        model_name=model_name,  # was hard-coded, ignoring the parameter
                                        num_labels=3,
                                        use_cuda=torch.cuda.is_available(),
                                        args=model_args)
  return model

def train_model_using_simpletransformers_lib(train_data, test_data, epochs=(4,), model_factory=None):
  """Train one fresh model per epoch budget; keep the one with fewest test errors.

  The author did not find a way to evaluate after every epoch inside
  `train_model`, so each entry of `epochs` trains a new model from scratch.

  Args:
    train_data: training frame with 'text' and 'labels' columns.
    test_data: held-out frame in the same format; must be non-empty.
    epochs: iterable of epoch counts to try.
    model_factory: callable invoked as `model_factory(train_epochs=n)` that
      returns an untrained model; defaults to `download_disitil_model`.

  Returns:
    (best_model, epochs_errors): the model with the fewest wrong test
    predictions (the earliest wins ties; None only if `epochs` is empty), and
    one human-readable error-rate line per tried epoch count.

  Raises:
    ValueError: if `test_data` is empty (the error rate would be undefined).
  """
  if model_factory is None:
    model_factory = download_disitil_model
  n_test = test_data.shape[0]
  if n_test == 0:
    raise ValueError("test_data is empty; cannot compute an error rate")

  epochs_errors = []
  best_model = None
  # Start at +inf so the first model is always kept, even if it gets every
  # test example wrong. Starting at n_test could leave best_model as None.
  wrong_pred_count = float('inf')
  for epoch in epochs:
    model = model_factory(train_epochs=epoch)
    model.train_model(train_data)
    _, _, wrong_predictions = model.eval_model(test_data)
    n_wrong = len(wrong_predictions)
    if n_wrong < wrong_pred_count:
      best_model = model
      wrong_pred_count = n_wrong
    epochs_errors.append(f"use {epoch} epochs. got in test: {n_wrong/n_test} error")
  return best_model, epochs_errors


In [7]:
def get_predictions(model, texts_list):
  """Predict labels for raw texts and return them next to the texts.

  Returns a DataFrame with one row per input text (in input order) and the
  columns 'Corrective', 'Adaptive', 'Perfective', followed by 'text'.
  """
  label_rows, _raw_outputs = model.predict(texts_list)
  label_frame = pd.DataFrame(label_rows, columns=['Corrective', 'Adaptive', 'Perfective'])
  return label_frame.assign(text=texts_list)

In [8]:
# Load the labelled commits, build the (text, labels) frame and hold out 15%
# for evaluation. A fixed seed keeps the split, and therefore the reported
# per-epoch error rates, reproducible across kernel restarts.
SPLIT_SEED = 42
raw_data = pd.read_csv(PATH_TO_LABELED_DATA)
data = data_preprocessing(raw_data)
train_data, test_data = train_test_split(data, train_size=0.85, test_size=0.15,
                                         random_state=SPLIT_SEED)

Проверяла, правда ли лучше использовать 4 эпохи (передав в train_model_using_simpletransformers_lib epochs=range(1,10)). 
По результатам на тестовой выборке было видно, что обучение в течение 4 эпох (как указано в статье) и правда дает наилучший результат.

In [9]:
# Train with the default epoch budget, then display each run's held-out error
# rate (previously computed but never shown).
trained_model, all_models_result = train_model_using_simpletransformers_lib(train_data, test_data)
all_models_result

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForMultiLabelSequenceClassification: ['vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForMultiLabelSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bia

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

  0%|          | 0/1731 [00:00<?, ?it/s]

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Running Epoch 0 of 4:   0%|          | 0/217 [00:00<?, ?it/s]

Running Epoch 1 of 4:   0%|          | 0/217 [00:00<?, ?it/s]

Running Epoch 2 of 4:   0%|          | 0/217 [00:00<?, ?it/s]

Running Epoch 3 of 4:   0%|          | 0/217 [00:00<?, ?it/s]

  0%|          | 0/306 [00:00<?, ?it/s]

Running Evaluation:   0%|          | 0/39 [00:00<?, ?it/s]

In [13]:
# The file holds one commit message per line. Read the lines directly instead
# of using read_csv(sep='\n'), for two reasons:
# - pandas may reject the line terminator as a separator;
# - CSV parsing would treat a message starting with '"' as a quoted field.
# Every line, including blank ones, gets a prediction, so rows align with the
# lines of the input file.
with open(PATH_TO_UNLABELED_DATA, encoding='utf-8') as unlabeled_file:
  texts_list = [line.rstrip('\n') for line in unlabeled_file]
raw_unlabelled = pd.DataFrame({'texts': texts_list})
texts_and_predictions = get_predictions(trained_model, texts_list)
texts_and_predictions.to_csv(PATH_TO_SAVE_PREDICTIONS)

  0%|          | 0/2521 [00:00<?, ?it/s]

  0%|          | 0/316 [00:00<?, ?it/s]

Проанализируем полученные предсказания:

In [25]:
# Group predicted commit texts by label code "<corrective><adaptive><perfective>",
# e.g. "101" = corrective + perfective. All 8 codes are present (lists may be empty).
values = [0, 1]
dict_of_commits_according_category = {
  f"{corr}{adap}{perf}": texts_and_predictions.loc[(texts_and_predictions['Corrective'] == corr)
                                                   & (texts_and_predictions['Adaptive'] == adap)
                                                   & (texts_and_predictions['Perfective'] == perf),
                                                   'text'].tolist()
  for corr in values
  for adap in values
  for perf in values
}

In [26]:
# Header row explains the three digits of each category code.
print("Corrective Adaptive Perfective")
for code, commit_messages in dict_of_commits_according_category.items():
  message_count = len(commit_messages)
  print(f"{code} count = {message_count}")

Corrective Adaptive Perfective
000 count = 164
001 count = 1391
010 count = 331
011 count = 26
100 count = 537
101 count = 61
110 count = 11
111 count = 0


Выводы по количеству коммитов в каждом классе

- видно, что ни один коммит не был помечен всеми тремя метками. Вряд ли это ошибка модели, т.к. коммит нечасто относится сразу ко всем трем типам изменений.
- больше всего коммитов было отнесено к классу Perfective. Либо это и правда самые частые изменения, либо модель относит к этой категории большую часть коммитов. Оценить, права ли модель, можно только эмпирически.

Последовательно посмотрим на каждую из категорий. Здесь приведу выводы по ячейкам кода ниже, чтобы собрать все наблюдения в одном месте.

1. Коммиты, не относящиеся ни к одной категории
- видно, что практически все сообщения содержат информацию об изменениях версий файлов/библиотек, создании файлов
- есть сообщения со словом "Fix", однако сами сообщения при этом малоинформативны
- коммиты об изменении версий также попали и в другие категории

2. Perfective-коммитами и правда считается больше коммитов, чем нужно
- при этом коммиты, отнесенные к Corrective-Perfective и Adaptive-Perfective, тоже похожи на отнесенные к Perfective

3. К Adaptive-коммитам в основном отнесены коммиты со словом add (хотя не все из них действительно относятся к этой категории)

4. Corrective на первый взгляд более консистентный. Также можно заметить, что многие коммиты, отнесенные к этой категории, содержат слово Fix

5. К Corrective Adaptive коммитам отнесены сообщения, содержащие словосочетание "add missing"

In [27]:
# Commits not assigned to any class (first 50, like the other large categories)
dict_of_commits_according_category["000"][:50]

['Create Protocol - Overview . ',
 'Make NOPASS . ',
 'update travis conf ',
 'Set Gradle project name for CI ',
 'Added STORM - 816 to Changelog ',
 'Bumped version to 1 . 3 . 3 ',
 'update travis config ',
 'Create CHANGELOG . md ',
 'Added STORM - 166 to Changlog ',
 'LPS - 48807 Remove EOL ',
 "Changed version back to ' Dev Build ' ",
 'Bump the version ',
 'Delete tablet . png ',
 'update libwebp 0 . 5 ',
 'Updated the version # ',
 'Added STORM - 414 to Changelog ',
 'Fixed version number ',
 'Fixed LONG_DESC ',
 'Fix # 47 - Update docs ',
 'Prepare version 25 - beta1 . ',
 'Upgrade version . ',
 'Adjust Iron Fluid Pipe Texture to match Item Pipe ',
 'Added STORM - 1206 to Changelog ',
 'Updated license date ',
 'Removing targetSdkVersion from library . ',
 'LPS - 54927 Remove settings ',
 'Fix setReleaseLabel ( ) . ',
 'Fix NPE in getForegroundTintList ',
 'Implement placeholder About dialog ',
 'Create README . md ',
 'Fix phrases ',
 'Updating Alloy to 2b40824 ; LPS - 27523 ',

In [29]:
# Perfective-only commits (first 50)
perfective_commits = dict_of_commits_according_category["001"]
perfective_commits[:50]

['Fix snapshot version ',
 'edit coverage colors icon ',
 'Added joscar JAR . ',
 'Fix TsExtractor tests ',
 'moving tools . jar inside the jre ',
 'Call close ( ) instead of deactivate ( ) in CursorToBulkCursorAdaptor . close ( ) ',
 'Fix typo in README . md ',
 'prepare release checkstyle - 7 . 1 . 1 ',
 'Updated the version string to the version of T4J ',
 'Ignore files generated by the sharpen process ',
 'Removed non - needed imports ',
 'Revert " Added Circle CI configuration " ',
 'remove classpath from manifest ',
 'LRQA - 17074 Modify java . jdk . type property from x64 to x32 ',
 'update fonts ',
 'Fix test data so that it can be compiled ',
 'Ignore time - zone data ',
 'Add notifications package index . ',
 'Update the changelog file ',
 'Remove fonts from sysui package . ',
 'prepare release v1 . 4 ',
 'updated wiki doco ',
 'Integrating new ADB USB debugging asset ',
 'Fix indentation ',
 'Updated jar ',
 'Remove ignoreSnapshots , we no longer use snapshot of Commons JCI 

In [30]:
# Adaptive-only commits (first 50)
adaptive_commits = dict_of_commits_according_category["010"]
adaptive_commits[:50]

['LRQA - 14419 Add new property to turn on running tests with poshi runner ',
 'update h2o - flow version to 0 . 5 . 0 with . . . ( # 684 ) ',
 'Updated webchat logo - added oracle logo as a separate image ',
 'Prepare 3 . 0 . 0 release ',
 'Set min width for add dialog . ',
 'update support annotations to 23 . 0 . 1 ',
 'added indexes ',
 'added missing OSGI - INF to build . properties ',
 'Ninja - add debug log statement to mv builder scheduled at startup ',
 'added conditions section ',
 'update linux - x86 natives ',
 'add STORM - 2081 to CHANGELOG ',
 'updated core to support new glob patterns ',
 'adds TODO for instantiateCroutonView ',
 'Add Review Board support . ',
 'add profvis package to install dependencies ',
 'add STORM - 1872 to CHANGELOG ',
 'add todo on replacing pstreams with native implementation ',
 'LPS - 50604 add @ Override ',
 'Prepared version 0 . 2 - SNAPSHOT . ',
 'PUBDEV - 2843 - added full regularization path extraction to python ',
 'add " \\ n " to proces

In [31]:
# Corrective-only commits (first 50)
corrective_commits = dict_of_commits_according_category["100"]
corrective_commits[:50]

['update chagelog ',
 'Remove observer only if it has been registered ',
 'missed shift ',
 'Add missing link to the 2 . 0 migration guide . ',
 'setting version to 1 . 0 . 133 - SNAPSHOT ',
 'Put that coffee down . ',
 'bump engins . io - client ',
 'fix error in docking station on chunk unload , fixes # 2898 ',
 'Remove reference to Melomel ',
 'Enabled Whole Module Optimization ',
 'Fix bug 558 , PImage . save ( ) method not working with get ( ) ',
 'run Findbugs on main only ( # 4131 ) ',
 'setting version to 3 . 0 . 1 - SNAPSHOT ',
 'Added Screens ',
 'Add mention of 4321 / 4411 to NEWS ',
 'Prevent insanely long passwords from crashing SystemUI ',
 'Fixes a crash of the QTKit video CaptureDevice on Snow Leopard reported by Yana Stamcheva . ',
 'Skip tests pending resolution of JBAS - 8339 ',
 'Fixed bug in Svn class with wrong release branch name . ',
 'Add travis ',
 'Update SpongeCommon for DestructEntityEvent cause fix . ',
 'Fixed problem with getHighestFnScope method ',
 'in

In [32]:
# Corrective + Adaptive commits
corrective_adaptive_commits = dict_of_commits_according_category["110"]
corrective_adaptive_commits

['adding missing library ',
 'add missing test data ',
 'add missing translation ',
 'added missing BenchmarkRunner config for JSON - lib JSON databind ',
 'added test for bug # 2527998 ',
 'Groovy Console - - Option to Auto - Save on Run - added missing action mapping entry ',
 'add missing branch ',
 'add correct exception throw declaration to isScreenBrightnessBoosted . ',
 'Updated Android support library , mainly to fix crash in Android 3 . x devices due to new API for notifications ',
 'add missing call to superclass method ',
 'added missing header ']

In [33]:
# Corrective + Perfective commits
corrective_perfective_commits = dict_of_commits_according_category["101"]
corrective_perfective_commits

['bump engine . io - client ',
 'updated todo . txt ',
 'set origin type with normal type ',
 'destroy folding model in three - side viewer ',
 'Temporarily disable JavaLineIndentProvider not to break tests ',
 'Removed groupId ( duplicates parent ) . ',
 'OSGi support - I forgot the ant task driver . ',
 'scrollbar image caused an exception and now is repared . ',
 "hibernate validator jar in GRAILS_HOME so hibernate plugin doesn ' t get resolve errors on load ",
 'updated snapshot version to 0 . 3 ',
 'Fix build break ',
 'Fix ignore list for prevent committing keystore file ',
 'Fixed grayscale on p . color . toHSB , which was ignoring range . ',
 'ignore iml files ',
 'Git ignore local . properties ',
 'updated libs ',
 'remove a wrong char " 7 " ',
 'Exploding Engines should kill themselves ',
 'rename a property to better reflect reality ( missed one instance in manual fixup of merge conflict ) ',
 'bump engine . io - client ',
 'updated name in readme ',
 'Send ACTION_DEVICE_POL

In [34]:
# Adaptive + Perfective commits
adaptive_perfective_commits = dict_of_commits_according_category["011"]
adaptive_perfective_commits

['allow panes to raise themselves again ',
 'Fix FrameworksServicesTests . apk to include libnativehelper . so ',
 'improved screenshot ',
 'LPS - 18566 replace xuggler jar with jdk5 version ',
 'updating new binaries for Exxamples and Library ',
 'Bump commons - pool2 to latest version ',
 'Fixes java 1 . 5 compatibility . ',
 'Corrected the groovyBundleVersion in build . properties for the new 1 . 8 . 0 - beta - 1 branch . ',
 'Add an armv6hf library for Serial ',
 'New build of the LLVM JNI dynamic lib for OS X with the changes added recently ',
 'Use a full - screen - width version of the in - app search dropdown with ',
 'update bullet native binaries ',
 'Use build tools version 18 . 1 for SeriesGuide modules . ',
 'create attributes and nested elements on demand ',
 'install debian packages essential for package building ',
 'require GWT 2 . 2 . 0 for its java . math support and testing improvements ',
 'Reflect current implementation in spec . ',
 "Updated JBQ ' s original Surf