In [51]:
#CODE BASED ON: https://towardsdatascience.com/multi-class-text-classification-with-deep-learning-using-bert-b59ca2f5c613

In [52]:
# !pip install -q transformers

In [53]:
import torch
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
from transformers import BertTokenizer
from torch.utils.data import TensorDataset

from transformers import BertForSequenceClassification

In [54]:
df = pd.read_csv('UpdatedResumeDataSet.csv')

In [55]:
df.head()

Unnamed: 0,Category,Resume
0,Data Science,Skills * Programming Languages: Python (pandas...
1,Data Science,Education Details \r\nMay 2013 to May 2017 B.E...
2,Data Science,"Areas of Interest Deep Learning, Control Syste..."
3,Data Science,Skills â¢ R â¢ Python â¢ SAP HANA â¢ Table...
4,Data Science,"Education Details \r\n MCA YMCAUST, Faridab..."


In [56]:
df['Category'].value_counts()

Java Developer               84
Testing                      70
DevOps Engineer              55
Python Developer             48
Web Designing                45
HR                           44
Hadoop                       42
Blockchain                   40
ETL Developer                40
Operations Manager           40
Data Science                 40
Sales                        40
Mechanical Engineer          40
Arts                         36
Database                     33
Electrical Engineering       30
Health and fitness           30
PMO                          30
Business Analyst             28
DotNet Developer             28
Automation Testing           26
Network Security Engineer    25
SAP Developer                24
Civil Engineer               24
Advocate                     20
Name: Category, dtype: int64

In [57]:
possible_labels = df.Category.unique()

label_dict = {}
for index, possible_label in enumerate(possible_labels):
    label_dict[possible_label] = index
label_dict

{'Data Science': 0,
 'HR': 1,
 'Advocate': 2,
 'Arts': 3,
 'Web Designing': 4,
 'Mechanical Engineer': 5,
 'Sales': 6,
 'Health and fitness': 7,
 'Civil Engineer': 8,
 'Java Developer': 9,
 'Business Analyst': 10,
 'SAP Developer': 11,
 'Automation Testing': 12,
 'Electrical Engineering': 13,
 'Operations Manager': 14,
 'Python Developer': 15,
 'DevOps Engineer': 16,
 'Network Security Engineer': 17,
 'PMO': 18,
 'Database': 19,
 'Hadoop': 20,
 'ETL Developer': 21,
 'DotNet Developer': 22,
 'Blockchain': 23,
 'Testing': 24}

In [58]:
df['label'] = df.Category.replace(label_dict)

In [59]:
df.head()

Unnamed: 0,Category,Resume,label
0,Data Science,Skills * Programming Languages: Python (pandas...,0
1,Data Science,Education Details \r\nMay 2013 to May 2017 B.E...,0
2,Data Science,"Areas of Interest Deep Learning, Control Syste...",0
3,Data Science,Skills â¢ R â¢ Python â¢ SAP HANA â¢ Table...,0
4,Data Science,"Education Details \r\n MCA YMCAUST, Faridab...",0


In [60]:
from sklearn.model_selection import train_test_split

X_train, X_val, y_train, y_val = train_test_split(df.index.values, 
                                                  df.label.values, 
                                                  test_size=0.5, 
                                                  random_state=42, 
                                                  stratify=df.label.values)

In [61]:
df['data_type'] = ['not_set']*df.shape[0]

df.loc[X_train, 'data_type'] = 'train'
df.loc[X_val, 'data_type'] = 'val'

In [62]:
df.groupby(['Category', 'label', 'Resume']).count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,data_type
Category,label,Resume,Unnamed: 3_level_1
Advocate,2,"Education Details \r\n B.Com, LL.B., University of Clacutta, University of Burdwan\r\nADVOCATE \r\n\r\n\r\nSkill Details \r\nTaxation matters Income Tax GST P Tax Accounts- Exprience - Less than 1 year months\r\nFiling of Income Tax Returns GST Returns e-TDS AIR and more- Exprience - Less than 1 year monthsCompany Details \r\ncompany - own practice\r\ndescription - 1. Drafting and preparation of plaint, Accounts and move before relevant Authority to hear the cases",2
Advocate,2,Education Details \r\n LLB. Dibrugarh University\r\nAdvocate \r\n\r\n\r\nSkill Details \r\nLegal.- Exprience - Less than 1 year monthsCompany Details \r\ncompany - Legal.\r\ndescription - â¢ Advocate,2
Advocate,2,"Education Details \r\nNovember 2016 to January 2019 Llm Masters in Law Hyderabad, Telangana Sultan Ul Uloom College Of Law\r\nSeptember 2011 to May 2016 BA.llb Bachelors in Law Hyderabad, Telangana Osmania University PG College Of Law\r\nAdvocate \r\n\r\nExperienced in Litigation, Recently Acquired Masters Degree in Law\r\nSkill Details \r\nMicrosoft word- Exprience - Less than 1 year months\r\nlitigation- Exprience - Less than 1 year months\r\nLegal Research- Exprience - Less than 1 year months\r\nContracts- Exprience - Less than 1 year months\r\nInternet Savvy- Exprience - Less than 1 year months\r\nDrafting- Exprience - Less than 1 year monthsCompany Details \r\ncompany - LRC Office\r\ndescription - â¢ Working under Senior Advocate L Ravichander in the High Court of Telangana.\r\nâ¢ Experience in drafting\r\nâ¢ Legal Correspondence",2
Advocate,2,"Good grasping quality and skillful work Education Details \r\nMarch 2013 to March 2018 B. A. LL. B. Law Solapur, Maharashtra Solapur University\r\nAdvocate \r\n\r\n\r\nSkill Details \r\nGood knowledge of typing as well as many other activities- Exprience - Less than 1 year monthsCompany Details \r\ncompany - District and Session court of solapur\r\ndescription - Forward thinking individual with refined interpersonal and multitasking skills. Looking to join a progressive organization to provide assistance in Legal work.\r\ncompany - District and Session court of solapur\r\ndescription - Provide legal assistance in legal work",2
Advocate,2,"QUALIFICATION: Introduction to Computer EXTRAEducation Details \r\nJanuary 2001 to January 2003 Master Law Chennai, Tamil Nadu Dr.Ambedkar Law University\r\nJanuary 1998 to January 2001 Bachelor Law Chennai, Tamil Nadu Dr. Ambedkar Law University\r\nJanuary 1995 to January 1998 Bachelor English Literature Tirunelveli, Tamil Nadu Manonmaniam Sundaranar university\r\nAdvocate \r\n\r\nAdvocate\r\nSkill Details \r\nCompany Details \r\ncompany - Practiced\r\ndescription - at\r\n\r\n* High Court of Judicature at Madras, India\r\n\r\n* City Civil Court, Chennai\r\n\r\n* Debt Recovery Tribunal, Chennai\r\n\r\n* Consumer Forums, Chennai\r\n\r\n* Labour Courts\r\n\r\n* Small Causes Courts\r\n\r\n* Rent control Courts\r\n* Legal advisor for Christian Institute of Management, Chennai in 2016\r\n* Legal Advisor for Ruah church, Chennai and NESSA Trust till 2018",2
...,...,...,...
Testing,24,"â Willingness to accept the challenges. â Positive thinking. â Good learner. â Team Player. DECLARATION: I hereby declare that the above mentioned information is correct up to my knowledge and I bear the responsibility for the correctness of the above mentioned particulars. Date: / / Name: Dongare Mandakini Murlidhar Signature: Education Details \r\nJune 2015 Electronics and Telecommunication Engineering Kolhapur, Maharashtra Shivaji University\r\nJune 2012 Education Secondary and Higher Secondary\r\n B.E. Electronics and Telecommunication Jaywant College of Engineering and Management\r\nTesting Engineer \r\n\r\nElectronics Engineer - Abacus Electronics Pvt Ltd\r\nSkill Details \r\nLanguage - C, C++- Exprience - Less than 1 year months\r\nOperating Systems- Windows 7-8/NT/XP- Exprience - Less than 1 year monthsCompany Details \r\ncompany - Abacus Electronics Pvt Ltd\r\ndescription - Duties:\r\nâ Perform electronic system testing for acceptance, compliance, warranty and other types.\r\nâ Develop test plan and procedure for electronic systems.\r\nâ Maintain complete and accurate documentations for system testing.\r\nâ Analyze and troubleshoot test defects in a timely fashion.\r\nâ Write system assembly instructions and resolve assembly issues accurately.\r\nâ Work with Supervisors to plan and coordinate test activities.\r\nâ Evaluate system performance and suggest improvements.\r\nâ Understand and interpret drawings, schematics, technical manuals and instructions.\r\nâ Also performed Hardware testing, debugging of hardware PCBs.\r\nâ Follow company policies and safely regulations.\r\nâ Work with cross-functional teams to complete assigned job duties within deadlines.\r\nâ Recommend process improvements to enhance testing efficiency.\r\ncompany - Minilec India Pvt Ltd , Pirangoot.\r\ndescription - ï¶\tTaking responsibility for the quality of a companyâs product.\r\nï¶\tWorking with the departmental manager, production staff and suppliers to ensure quality, they aim to minimize the cost of reworking or waste and maximize customer satisfaction with the product.\r\nï¶\tTo establish, implement and maintain quality management system to measure and control quality in the production process.\r\nï¶\tWork with the aim that to eliminate the causes of quality issues and reduce the risk of failure.",10
Web Designing,4,"Education Details \r\n B.C.A Bachelor Computer Application Pune, Maharashtra Pune University\r\n H.S.C. Pune, Maharashtra Pune University\r\n S.S.C. Pune, Maharashtra Pune University\r\nWeb Designing and Developer \r\n\r\nphp Developer - Exposys Pvt. Ltd\r\nSkill Details \r\nCompany Details \r\ncompany - Exposys Pvt. Ltd\r\ndescription - Technical Skills\r\nWeb Development: HTML5, CSS3, Bootstrap, PHP, Ajax, Jquery, JavaScript.\r\nDatabase: MySQL.\r\nDevelopment Tools: Notepad++, Sublime Text2.\r\nFramework: Codeigniter.\r\nServer: Apache tomcat, Xampp Control Panel.\r\nOperating Systems: Windows.\r\ncompany - Exposys Pvt. Ltd\r\ndescription - Pune.\tAugest 2017 to till date\r\n\r\nProject Details:\r\nProject-I: Pragat Bharat System\r\nTechnologies Used: HTML, CSS, BOOTSTRAP, PHP, JQUERY, AJAX.\r\nDatabase Used: My SQL.\r\nTeam size: 1\r\nPosition: Software Developer\r\nSynopsis: This project aim is specially design for people. It is used to collect information to diifernt sector.\r\n\r\nProject-II: Go Ayur System\r\nTechnologies Used: HTML, CSS, BOOTSTRAP, PHP, JQUERY, AJAX.\r\nDatabase Used: My SQL.\r\nTeam size: 2\r\nPosition: Software Developer\r\nSynopsis: Go Ayurveda Panchakarma center is one of most traditionally well established, professional and innovative providers of Classical\r\nAyurvedic Health services and Kerala Panchakarma therapies.\r\n\r\nProject-III: Vitsanindia System\r\nTechnologies Used: HTML, CSS, BOOTSTRAP, PHP, JQUERY, AJAX, JAVA SCRIPT.\r\nDatabase Used: My SQL.\r\nTeam size: 2\r\nPosition: Software Developer\r\nSynopsis: Online Shooping through app. This app is user friendly because there is a option for change language. User can to find different categories products as there choice.\r\n\r\nProject-IV: MahabaleshwarTours\r\nTechnologies Used: HTML, CSS, BOOTSTRAP, PHP, JQUERY, AJAX, JAVA SCRIPT.\r\nDatabase Used: My SQL.\r\nTeam size: 1\r\nPosition: Software Developer\r\nSynopsis: In this system is to provide Online Registration, Tour Package Information, Ticket Booking, Online Payment and Searching Facility for Customer and also Generate Different types of Report.\r\n\r\nProject-V: Cityspaceindia\r\nTechnologies Used: HTML, CSS, BOOTSTRAP, PHP, JQUERY, AJAX, JAVA SCRIPT.\r\nDatabase Used: My SQL.\r\nTeam size: 1\r\nPosition: Software Developer\r\nSynopsis: Service provider website we provide different categories.\r\n\r\nProject-VI: Fruitsbuddy\r\nTechnologies Used: HTML, CSS, BOOTSTRAP, PHP, JQUERY, AJAX, JAVA SCRIPT.\r\nDatabase Used: My SQL.\r\nTeam size: 1\r\nPosition: Software Developer\r\nSynopsis: Fruitbuddy is to manage the details of fruits, Customer, Order, Transaction, Payment. It manages all the information about fruits, Stocks, Payment. The project is totally built at administrative end and thus only the administrator is guaranteed the access. The purpose of the project is to build an application program to reduce the manual work for managing the fruits, Customer, Stocks, Order.\r\n\r\nProject-VII: Totalcitee\r\nTechnologies Used: HTML, CSS, BOOTSTRAP, PHP, JQUERY, AJAX, JAVA SCRIPT.\r\nDatabase Used: My SQL.\r\nTeam size: 1\r\nPosition: Software Developer\r\nSynopsis: Real Estate web application has been created for helping you to sell properties through web based user interface. Visitors on your website can view particular desired products using search engine facility.\r\n\r\nProject-VIII: Golchha\r\nTechnologies Used: HTML, CSS, BOOTSTRAP, PHP, JQUERY, AJAX, JAVA SCRIPT.\r\nDatabase Used: My SQL.\r\nTeam size: 1\r\nPosition: Software Developer\r\nSynopsis: Service provider website we provide different categories.",9
Web Designing,4,"Education Details \r\nJanuary 2016 B.Sc. Information Technology Mumbai, Maharashtra University of Mumbai\r\nJanuary 2012 HSC Allahabad, Uttar Pradesh Allahabad university\r\nJanuary 2010 SSC dot Net Allahabad, Uttar Pradesh Allahabad university\r\nWeb designer and Developer Trainer \r\n\r\nWeb designer and Developer\r\nSkill Details \r\nWeb design- Exprience - 12 months\r\nPhp- Exprience - 12 monthsCompany Details \r\ncompany - NetTech India\r\ndescription - Working. ( salary - 12k)\r\nPERSONAL INTEREST\r\n\r\nListening to Music, Surfing net, Watching Movie, Playing Cricket.\r\ncompany - EPI Center Academy\r\ndescription - Working. ( Salary Contract based)\r\ncompany - Aptech Charni Road\r\ndescription - Salary Contract based)",9
Web Designing,4,"IT SKILLS Languages: C (Basic), JAVA (Basic) Web Technologies: HTML5, CSS3, Bootstrap, JavaScript, jQuery, Corel Draw, Photoshop, Illustrator Databases: MySQL5.0 IDE & Tools: Sublime Text, Notepad Operating Systems: Windows XP, Windows 7Education Details \r\nSeptember 2015 Bachelor of Engineer Information technology Nagpur, Maharashtra Nagpur University\r\nMay 2011 HSC Secondary & Higher Secondary State Board of Secondary\r\nJune 2009 SSC Secondary & Higher Secondary Maharashtra State Board of Secondary\r\nWeb and Graphics Designer \r\n\r\nWeb and Graphics Designer - Virtuous Media Point, Pune\r\nSkill Details \r\nBOOTSTRAP- Exprience - 24 months\r\nHTML5- Exprience - 24 months\r\nJAVASCRIPT- Exprience - 24 months\r\njQuery- Exprience - 24 months\r\nCOREL DRAW- Exprience - 24 months\r\nAdobe Photoshop- Exprience - 24 months\r\nAdobe Illustrator- Exprience - 12 months\r\nCSS3- Exprience - 24 monthsCompany Details \r\ncompany - Virtuous Media Point\r\ndescription - \r\ncompany - CNC Web World\r\ndescription - Internship Program: At e-sense IT Solution pvt.ltd. Nagpur as a Web Designing and Developement.\r\n* Presented in Project Competition in Innovesta 15 of Priyadarshini Indira Gandhi College of Engineering, Nagpur.\r\n* Presented in National Level Paper Presentation in TECH-WAVE 2015 of S.R.M.C.E., Nagpur.\r\ncompany - e-sense IT Solution pvt.ltd\r\ndescription - Key Result Areas:\r\n* Designed websites solutions by studying information needs, conferring with users, and studying systems flow, data usage, and work processes.\r\n* Understood process requirements and provided use cases for business, functional & technical requirements.\r\n* Interacted with users for requirement gathering, prepared functional specifications and low-level design documents.\r\n* Participated in the Software Development Life cycle (SDLC) and Agile methodology right from requirement analysis,\r\n* Performed detailed design of modules along with their implementation, and documentation integrated software modules\r\nDeveloped by other team members.\r\n\r\nHighlights:\r\n* Developed various modules as per customer requirement and identified and fixed number of bugs related to code, Database connectivity, UI Defects and so on.\r\n* Analyzed and modified existing codes to incorporate a number of changes in the application / user requirements, wrote new codes as required.\r\n* Coded, implemented and integrated complex programs using technologies such as HTML5, CSS3, JavaScript, jQuery, bootstrap.\r\n* Having good command on Graphics designing with effective ideas.\r\n\r\nPROJECTS\r\n\r\n* www.nitka.com, Nagpur united corporation (admin), Mintmetrix.com, Tagline videos (admin), Smartbadge (admin): -\r\nIn all projects I have used technologies like HTML5, CSS3, Bootstrap, JavaScript, jQuery and text editor as sublime text.\r\n\r\n* www.shreekiaspack.co.in, www.3staragroproducts.com, www.luckystationery.co.in: - used technologies like HTML5, CSS3,\r\nBootstrap, javascript and text editor as notepad++.\r\n\r\n* Design various Logos, Brochures, Advertising Banners, Visiting Cards, Pamphlet, Hoardings etc.\r\n\r\nB.E. FINAL YEAR PROJECT\r\n\r\n* Major Project: -\r\n\r\nTitle: WEB BASED DISEASE DIAGNOSIS EXPERT SYSTEM.\r\nDuration: 1 Year\r\n\r\nDescription: In this project we provide a website in which doctor gives online consultation for particular disease. System gives better suggestions for any health problems.\r\n\r\n* Mini Project Development-\r\n\r\n* SHOPPING MANAGEMENT SYSTEM Developed in C++.\r\n\r\nCURRICULUM & EXTRA CURRICULUM ACTIVITIES\r\ncompany - FACE-IT\r\ndescription - Co-ordinator in project competition.",9


In [63]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', 
                                          do_lower_case=True)

In [64]:
encoded_data_train = tokenizer.batch_encode_plus(
    df[df.data_type=='train'].Resume.values, 
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=512, 
    return_tensors='pt'
)

encoded_data_val = tokenizer.batch_encode_plus(
    df[df.data_type=='val'].Resume.values, 
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=512, 
    return_tensors='pt'
)


input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(df[df.data_type=='train'].label.values)

input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(df[df.data_type=='val'].label.values)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [65]:
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

In [66]:
len(dataset_train), len(dataset_val)

(481, 481)

In [67]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [68]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

batch_size = 3

dataloader_train = DataLoader(dataset_train, 
                              sampler=RandomSampler(dataset_train), 
                              batch_size=batch_size)

dataloader_validation = DataLoader(dataset_val, 
                                   sampler=SequentialSampler(dataset_val), 
                                   batch_size=batch_size)

In [69]:
from transformers import AdamW, get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(),
                  lr=1e-5, 
                  eps=1e-8)



In [70]:
epochs = 5

scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_train)*epochs)

In [71]:
from sklearn.metrics import f1_score

def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average='weighted')

def accuracy_per_class(preds, labels):
    label_dict_inverse = {v: k for k, v in label_dict.items()}
    
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        print(f'Class: {label_dict_inverse[label]}')
        print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\n')

def accuracy_count(preds, labels):
    label_dict_inverse = {v: k for k, v in label_dict.items()}
    
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    count = 0
    count_t = 0
    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        count += len(y_preds[y_preds==label])
        count_t += len(y_true)
    print(f'Correct: {count}')
    print(f'Total: {count_t}\n')     
    return count, count_t

In [72]:
import random

seed_val = 17
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

In [73]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(device)

cuda


In [74]:
def evaluate(dataloader_val):

    model.eval()
    
    loss_val_total = 0
    predictions, true_vals = [], []
    
    for batch in dataloader_val:
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }

        with torch.no_grad():        
            outputs = model(**inputs)
            
        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()

        logits = logits.detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)
    
    loss_val_avg = loss_val_total/len(dataloader_val) 
    
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
            
    return loss_val_avg, predictions, true_vals

In [75]:
for epoch in tqdm(range(1, epochs+1)):
    
    model.train()
    
    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:

        model.zero_grad()
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }       

        outputs = model(**inputs)
        
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})
         
        
    torch.save(model.state_dict(), f'finetuned_BERT_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader_train)            
    tqdm.write(f'Training loss: {loss_train_avg}')
    
    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/161 [00:00<?, ?it/s]


Epoch 1
Training loss: 3.087625959645147
Validation loss: 2.768173586507762
F1 Score (Weighted): 0.14197988426917674


Epoch 2:   0%|          | 0/161 [00:00<?, ?it/s]


Epoch 2
Training loss: 2.5555868918851297
Validation loss: 2.250955013014515
F1 Score (Weighted): 0.5322145323785036


Epoch 3:   0%|          | 0/161 [00:00<?, ?it/s]


Epoch 3
Training loss: 2.0794198053964177
Validation loss: 1.8455201220808561
F1 Score (Weighted): 0.7975994942656546


Epoch 4:   0%|          | 0/161 [00:00<?, ?it/s]


Epoch 4
Training loss: 1.785295253214629
Validation loss: 1.6188831047982162
F1 Score (Weighted): 0.865603231765434


Epoch 5:   0%|          | 0/161 [00:00<?, ?it/s]


Epoch 5
Training loss: 1.615055935723441
Validation loss: 1.539924937745799
F1 Score (Weighted): 0.8710595439421351


In [76]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)

model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [77]:
model.load_state_dict(torch.load('finetuned_BERT_epoch_5.model', map_location=torch.device('cuda')))

<All keys matched successfully>

In [78]:
_, predictions, true_vals = evaluate(dataloader_validation)

In [79]:
accuracy_per_class(predictions, true_vals)

Class: Data Science
Accuracy: 20/20

Class: HR
Accuracy: 22/22

Class: Advocate
Accuracy: 10/10

Class: Arts
Accuracy: 18/18

Class: Web Designing
Accuracy: 23/23

Class: Mechanical Engineer
Accuracy: 13/20

Class: Sales
Accuracy: 20/20

Class: Health and fitness
Accuracy: 15/15

Class: Civil Engineer
Accuracy: 12/12

Class: Java Developer
Accuracy: 42/42

Class: Business Analyst
Accuracy: 8/14

Class: SAP Developer
Accuracy: 0/12

Class: Automation Testing
Accuracy: 3/13

Class: Electrical Engineering
Accuracy: 15/15

Class: Operations Manager
Accuracy: 20/20

Class: Python Developer
Accuracy: 24/24

Class: DevOps Engineer
Accuracy: 25/27

Class: Network Security Engineer
Accuracy: 13/13

Class: PMO
Accuracy: 15/15

Class: Database
Accuracy: 16/16

Class: Hadoop
Accuracy: 21/21

Class: ETL Developer
Accuracy: 14/20

Class: DotNet Developer
Accuracy: 2/14

Class: Blockchain
Accuracy: 20/20

Class: Testing
Accuracy: 35/35



In [80]:
x,y = accuracy_count(predictions, true_vals)
x/y


Correct: 426
Total: 481



0.8856548856548857

In [81]:
#70/30 split data
# Class: Data Science
# Accuracy: 12/12

# Class: HR
# Accuracy: 13/13

# Class: Advocate
# Accuracy: 6/6

# Class: Arts
# Accuracy: 11/11

# Class: Web Designing
# Accuracy: 14/14

# Class: Mechanical Engineer
# Accuracy: 12/12

# Class: Sales
# Accuracy: 12/12

# Class: Health and fitness
# Accuracy: 9/9

# Class: Civil Engineer
# Accuracy: 7/7

# Class: Java Developer
# Accuracy: 25/25

# Class: Business Analyst
# Accuracy: 8/8

# Class: SAP Developer
# Accuracy: 7/7

# Class: Automation Testing
# Accuracy: 5/8

# Class: Electrical Engineering
# Accuracy: 9/9

# Class: Operations Manager
# Accuracy: 12/12

# Class: Python Developer
# Accuracy: 14/14

# Class: DevOps Engineer
# Accuracy: 16/17

# Class: Network Security Engineer
# Accuracy: 8/8

# Class: PMO
# Accuracy: 9/9

# Class: Database
# Accuracy: 10/10

# Class: Hadoop
# Accuracy: 13/13

# Class: ETL Developer
# Accuracy: 12/12

# Class: DotNet Developer
# Accuracy: 1/8

# Class: Blockchain
# Accuracy: 12/12

# Class: Testing
# Accuracy: 21/21

# Correct: 278
# Total: 289
# 0.9619377162629758

# Wrong:
# Class: Automation Testing
# Accuracy: 5/8
# Class: DevOps Engineer
# Accuracy: 16/17
# Class: DotNet Developer
# Accuracy: 1/8

In [None]:
#50/50 split data
# Class: Data Science
# Accuracy: 20/20

# Class: HR
# Accuracy: 22/22

# Class: Advocate
# Accuracy: 10/10

# Class: Arts
# Accuracy: 18/18

# Class: Web Designing
# Accuracy: 23/23

# Class: Mechanical Engineer
# Accuracy: 13/20

# Class: Sales
# Accuracy: 20/20

# Class: Health and fitness
# Accuracy: 15/15

# Class: Civil Engineer
# Accuracy: 12/12

# Class: Java Developer
# Accuracy: 42/42

# Class: Business Analyst
# Accuracy: 8/14

# Class: SAP Developer
# Accuracy: 0/12

# Class: Automation Testing
# Accuracy: 3/13

# Class: Electrical Engineering
# Accuracy: 15/15

# Class: Operations Manager
# Accuracy: 20/20

# Class: Python Developer
# Accuracy: 24/24

# Class: DevOps Engineer
# Accuracy: 25/27

# Class: Network Security Engineer
# Accuracy: 13/13

# Class: PMO
# Accuracy: 15/15

# Class: Database
# Accuracy: 16/16

# Class: Hadoop
# Accuracy: 21/21

# Class: ETL Developer
# Accuracy: 14/20

# Class: DotNet Developer
# Accuracy: 2/14

# Class: Blockchain
# Accuracy: 20/20

# Class: Testing
# Accuracy: 35/35

# Correct: 426
# Total: 481
# 0.8856548856548857

# Class: Mechanical Engineer
# Accuracy: 13/20
# Class: Business Analyst
# Accuracy: 8/14
# Class: SAP Developer
# Accuracy: 0/12
# Class: Automation Testing
# Accuracy: 3/13
# Class: DevOps Engineer
# Accuracy: 25/27
# Class: ETL Developer
# Accuracy: 14/20
# Class: DotNet Developer
# Accuracy: 2/14