## Functions for extracting flight level (FL), heading, communication info and call signs

In [1]:
import pandas as pd
import torch 
from transformers import BertTokenizerFast, BertForTokenClassification

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Vocabulary of spoken number words; 'numbers.csv' must provide a 'words' column.
df_numbers = pd.read_csv('numbers.csv')
# Hyphenated entries ("twenty-one") are rewritten with a space. Assigning the whole
# column avoids chained assignment (`df['words'][i] = ...`), which raises
# SettingWithCopyWarning and silently has no effect under pandas Copy-on-Write.
df_numbers['words'] = df_numbers['words'].str.replace('-', ' ', regex=False)
# NOTE(review): the extractors compare one whitespace-separated token at a time, so
# entries that now contain a space (e.g. "twenty one") can never match — confirm
# whether they should be split into their component words instead.
numbers_list = df_numbers['words'].tolist()

# Tokens accepted inside a frequency after "contact" (e.g. "one two seven decimal five").
decimal_list=['.','decimal','point']

#FL#
# Filler/connector tokens allowed inside a flight-level phrase.
miscellaneous_list_fl=['ehm','ah','of','feet','and','below','altitude','to','then']
# Action words searched before 'level' / climb / descend triggers.
# Note: 'below' is also listed in miscellaneous_list_fl.
fl_action_list=['maintain','maintaining','climb','climbing','descend','descending','descent','passing','approaching','approach','below','above','request','requesting']
#FL#

#COMM#
# Filler tokens allowed after "contact" / "squawk".
miscellaneous_list_communication=['ehm','ah','the','on','frequency','now','correction','is']
#COMM#

#HEADING#
# Filler tokens skipped (not reported) inside the digits that follow 'heading'.
miscellaneous_list_heading=['ehm','ah','of','maintain']
# Action words searched before 'heading'.
heading_action_list=['right','left','turn','continue','report','maintain','maintaining','present','request','fly','set']
#HEADING#

In [3]:
#location list for communication info
# Station / facility names that may appear in a "contact ..." hand-off.
location_list = [
    'boston', 'bratislava', 'control', 'geneva', 'ground',
    'karlovy vary', 'kbely', 'krakow', 'marseille', 'frankfurt',
    'milan', 'milano', 'munich', 'munchen', 'ostrava',
    'paris', 'praha', 'radar', 'rhein', 'ruzyne',
    'tower', 'us', 'vienna', 'reims', 'warsaw',
    'wien', 'zurich',
]

# Utterances are matched one whitespace token at a time, so multi-word names such
# as "karlovy vary" are broken into their individual words.
location_list_splitted = {
    word
    for location in location_list
    for word in location.split()
}

In [4]:
#function for extracting heading information
def extract_heading_info(utterance, *, numbers=None, misc_words=None,
                         action_words=None, lookback=4):
    """Extract heading instructions from a transcribed ATC utterance.

    For every token equal to ``'heading'`` the result contains:

    * the action words (upper-cased, in original order) found among the
      ``lookback`` tokens immediately before it,
    * the literal ``'heading'``,
    * the number words that follow it. Filler words from ``misc_words`` may be
      interleaved with the numbers; they are skipped without ending the scan.

    Args:
        utterance: whitespace-separated transcript, e.g.
            ``"swiss one two four b turn left heading zero one zero"``.
        numbers: number-word vocabulary (default: global ``numbers_list``).
        misc_words: filler words tolerated inside the digits
            (default: global ``miscellaneous_list_heading``).
        action_words: words qualifying the heading such as ``turn``/``left``
            (default: global ``heading_action_list``).
        lookback: number of tokens before ``'heading'`` searched for action words.

    Returns:
        All matches joined by single spaces, e.g.
        ``"TURN LEFT heading zero one zero"``, or ``None`` if the utterance has
        no ``'heading'`` token.
    """
    # Defaults are resolved at call time so the module-level vocabularies can be
    # edited, or overridden per call, without redefining the function.
    numbers = numbers_list if numbers is None else numbers
    misc_words = miscellaneous_list_heading if misc_words is None else misc_words
    action_words = heading_action_list if action_words is None else action_words

    words = utterance.split()
    heading_info = []
    i = 0
    while i < len(words):
        if words[i] != 'heading':
            i += 1
            continue

        window = words[max(0, i - lookback):i]
        prefix = [w.upper() for w in window if w in action_words]

        # Consume numbers and filler after 'heading'; only the numbers are kept.
        j = i + 1
        digits = []
        while j < len(words) and (words[j] in numbers or words[j] in misc_words):
            if words[j] in numbers:
                digits.append(words[j])
            j += 1

        heading_info.extend(prefix + ['heading'] + digits)
        # Resume at the first token the scan did not consume.
        i = j

    return ' '.join(heading_info) if heading_info else None

In [5]:
#function to extract radio frequency or transponder code
def extract_communication_info(utterance, *, numbers=None, locations=None,
                               decimals=None, misc_words=None):
    """Extract a frequency hand-off ("contact ...") or transponder code ("squawk ...").

    Contact phrases take precedence: if any are found, they are returned and
    squawk instructions are ignored. Otherwise squawk phrases are returned.

    Args:
        utterance: whitespace-separated transcript.
        numbers: number-word vocabulary (default: global ``numbers_list``).
        locations: single-word station names
            (default: global ``location_list_splitted``).
        decimals: decimal-point tokens allowed inside a frequency
            (default: global ``decimal_list``).
        misc_words: filler tokens (default: global
            ``miscellaneous_list_communication``).

    Returns:
        e.g. ``"CONTACT praha one two seven decimal one two five"`` or
        ``"SQUAWK four five two one"``; ``None`` when nothing matches.
    """
    # Defaults are resolved at call time so the module-level vocabularies can be
    # edited, or overridden per call, without redefining the function.
    numbers = numbers_list if numbers is None else numbers
    locations = location_list_splitted if locations is None else locations
    decimals = decimal_list if decimals is None else decimals
    misc_words = miscellaneous_list_communication if misc_words is None else misc_words

    words = utterance.split()
    contact_info = _extract_contact_info(words, numbers, locations, decimals, misc_words)
    if contact_info is not None:
        return contact_info
    # No usable "contact" phrase (e.g. only "radar contact"): still report a squawk.
    return _extract_squawk_info(words, numbers, misc_words)


# Accepted spellings of the squawk instruction ('squak' is a transcription variant).
_SQUAWK_WORDS = ('squawk', 'squawking', 'squak')


def _extract_contact_info(words, numbers, locations, decimals, misc_words):
    """Collect 'CONTACT <station/frequency words>' phrases; joined string or None."""
    def starts_phrase(word):
        return word in numbers or word in locations or word in misc_words

    contact_info = []
    i = 0
    # 'contact' must be followed by at least one token, so the last index is never a start.
    while i < len(words) - 1:
        # "radar contact" is skipped (presumably an identification phrase rather than
        # a hand-off). The i > 0 guard stops words[-1] (the LAST token) from being
        # mistaken for the preceding word when 'contact' is the first token.
        after_radar = i > 0 and words[i - 1] == 'radar'
        if words[i] == 'contact' and starts_phrase(words[i + 1]) and not after_radar:
            j = i + 1
            while j < len(words) and (starts_phrase(words[j]) or words[j] in decimals):
                j += 1
            contact_info.extend(['CONTACT'] + words[i + 1:j])
            i = j
        else:
            i += 1
    return ' '.join(contact_info) if contact_info else None


def _extract_squawk_info(words, numbers, misc_words):
    """Collect 'SQUAWK <code words>' phrases; joined string or None."""
    squawk_info = []
    i = 0
    while i < len(words):
        if words[i] not in _SQUAWK_WORDS:
            i += 1
            continue
        j = i + 1
        while j < len(words) and (words[j] in numbers or words[j] in misc_words):
            j += 1
        squawk_info.extend(['SQUAWK'] + words[i + 1:j])
        i = j
    return ' '.join(squawk_info) if squawk_info else None

In [6]:
#function to extract flight level
def extract_fl_info(utterance, *, numbers=None, misc_words=None,
                    action_words=None, lookback=4):
    """Extract flight-level / altitude instructions from an ATC utterance.

    Two phrasings are recognised:

    1. ``level`` phrases. For every ``'level'`` token, emit the action words
       (upper-cased) and ``'flight'`` found among the ``lookback`` preceding
       tokens, then ``'level'``, then the following run of number/filler words.
       Example: ``"CLIMB flight level one two zero"``.
    2. Only if (1) finds nothing: ``climb``/``descend`` phrases. For every
       ``descend``/``descending``/``climb``/``climbing`` token, emit the preceding
       action words, the upper-cased trigger, and the following run of number,
       filler and action words (action words upper-cased).
       Example: ``"DESCENDING to four zero"``.

    Args:
        utterance: whitespace-separated transcript.
        numbers: number-word vocabulary (default: global ``numbers_list``).
        misc_words: filler tokens (default: global ``miscellaneous_list_fl``).
        action_words: action words (default: global ``fl_action_list``).
        lookback: number of tokens before a trigger searched for action words.

    Returns:
        Matched phrases joined by single spaces, or ``None`` if nothing matched.
    """
    # Defaults are resolved at call time so the module-level vocabularies can be
    # edited, or overridden per call, without redefining the function.
    numbers = numbers_list if numbers is None else numbers
    misc_words = miscellaneous_list_fl if misc_words is None else misc_words
    action_words = fl_action_list if action_words is None else action_words

    words = utterance.split()
    # 'level' phrases take precedence. Fall back to climb/descend phrasing when there
    # are none, including when 'level' only occurs inside another word ("levelling").
    result = _fl_from_level(words, numbers, misc_words, action_words, lookback)
    if result is None:
        result = _fl_from_climb_descend(words, numbers, misc_words, action_words, lookback)
    return result


_CLIMB_DESCEND_WORDS = ('descend', 'descending', 'climb', 'climbing')


def _fl_from_level(words, numbers, misc_words, action_words, lookback):
    """Collect every 'level' phrase in ``words``; joined string or None."""
    fl_info = []
    i = 0
    while i < len(words):
        if words[i] != 'level':
            i += 1
            continue
        window = words[max(0, i - lookback):i]
        prefix = ['flight' if w == 'flight' else w.upper()
                  for w in window if w in action_words or w == 'flight']
        j = i + 1
        while j < len(words) and (words[j] in numbers or words[j] in misc_words):
            j += 1
        fl_info.extend(prefix + ['level'] + words[i + 1:j])
        i = j  # resume at the first token not consumed by this phrase
    return ' '.join(fl_info) if fl_info else None


def _fl_from_climb_descend(words, numbers, misc_words, action_words, lookback):
    """Collect every climb/descend phrase in ``words``; joined string or None."""
    fl_info = []
    i = 0
    while i < len(words):
        if words[i] not in _CLIMB_DESCEND_WORDS:
            i += 1
            continue
        window = words[max(0, i - lookback):i]
        prefix = [w.upper() for w in window if w in action_words]
        body = []
        j = i + 1
        while j < len(words) and (words[j] in numbers or words[j] in misc_words
                                  or words[j] in action_words):
            body.append(words[j].upper() if words[j] in action_words else words[j])
            j += 1
        fl_info.extend(prefix + [words[i].upper()] + body)
        # Continue after the consumed run. A trigger swallowed by this phrase (the
        # final 'climb' in "descend and climb") must not be reported a second time.
        i = j
    return ' '.join(fl_info) if fl_info else None
                        

In [7]:
#call sign extraction
tokenizer = BertTokenizerFast.from_pretrained('bert_ner/bert_ner_tokenizer')

# BIO tags for call-sign spans; integer ids follow the sorted tag order.
unique_labels = ['B-call', 'I-call', 'O']
ids_to_labels = dict(enumerate(sorted(unique_labels)))
labels_to_ids = {label: idx for idx, label in ids_to_labels.items()}

class BertModel(torch.nn.Module):
    """Token-classification model used to tag call-sign words.

    Wraps ``BertForTokenClassification`` as the ``bert`` attribute, so every
    state-dict key carries a ``bert.`` prefix. This is presumably the layout of
    the saved NER checkpoint; confirm before renaming the attribute.
    NOTE(review): the class name shadows ``transformers.BertModel``; it is kept
    for compatibility with existing callers.
    """

    def __init__(self, pretrained_name='bert-base-cased', num_labels=None):
        """Build the token-classification head.

        Args:
            pretrained_name: Hugging Face checkpoint used to initialise BERT.
            num_labels: size of the tag set; defaults to ``len(unique_labels)``.
        """
        super().__init__()
        if num_labels is None:
            num_labels = len(unique_labels)
        self.bert = BertForTokenClassification.from_pretrained(
            pretrained_name, num_labels=num_labels)

    def forward(self, input_id, mask, label=None):
        """Run the classifier.

        Args:
            input_id: token ids produced by the tokenizer.
            mask: attention mask aligned with ``input_id``.
            label: optional per-token label ids (only needed for training/loss).

        Returns:
            The tuple returned by ``BertForTokenClassification`` with
            ``return_dict=False``. The loss comes first only when ``label`` is given.
        """
        return self.bert(input_ids=input_id, attention_mask=mask, labels=label,
                         return_dict=False)
    
model = BertModel()
# NOTE(review): torch.load unpickles the checkpoint. Only load files from a trusted
# source (recent PyTorch versions accept weights_only=True for plain state dicts).
model.load_state_dict(torch.load('bert_ner/ner_state_dict', map_location=torch.device('cpu')))
# Inference only: make evaluation mode (no dropout) explicit for the whole wrapper
# instead of relying on from_pretrained having set it on the inner model.
model.eval()


def align_word_ids(texts, label_all_tokens=False):
    """Build a per-token mask marking which sub-tokens carry a word-level prediction.

    ``texts`` is tokenized exactly as in ``evaluate_one_text`` (padded/truncated
    to 512 tokens), so the returned list lines up with that model input.

    Args:
        texts: the sentence to tokenize. Only the first sequence's word ids are used.
        label_all_tokens: if True, every sub-token of a word is marked. If False
            (the default), only the first sub-token is, which yields exactly one
            prediction per word.

    Returns:
        A list of 512 ints: ``1`` for marked tokens, ``-100`` for special/padding
        tokens and for unmarked sub-tokens.
    """
    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)
    word_ids = tokenized_inputs.word_ids()

    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            # [CLS], [SEP] and padding do not belong to any word.
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            # First sub-token of a new word.
            label_ids.append(1)
        else:
            # Continuation sub-token of the same word.
            label_ids.append(1 if label_all_tokens else -100)
        previous_word_idx = word_idx

    return label_ids


def evaluate_one_text(model, sentence):
    """Tag each word of ``sentence`` with a call-sign BIO label.

    Args:
        model: the ``BertModel`` wrapper; moved to CUDA when a GPU is available.
            NOTE(review): callers are expected to have put it in eval mode.
        sentence: whitespace-separated transcript.

    Returns:
        One label (``'B-call'``, ``'I-call'`` or ``'O'``) per word, taken from the
        first sub-token of each word. An ``'I-call'`` that does not continue a
        ``B-call``/``I-call`` span is rewritten to ``'O'``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if use_cuda:
        model = model.cuda()

    text = tokenizer(sentence, padding='max_length', max_length = 512, truncation=True, return_tensors="pt")

    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    # -100 marks special/padding tokens and non-initial sub-tokens; their logits are dropped.
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)

    # Pure inference: skip building an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(input_id, mask, None)
    logits_clean = logits[0][label_ids != -100]

    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels[p] for p in predictions]

    # The list is repaired in place, so a position sees the already-repaired
    # label of its predecessor.
    for i, label in enumerate(prediction_label):
        if label == 'I-call' and (i == 0 or prediction_label[i - 1] not in ('B-call', 'I-call')):
            prediction_label[i] = 'O'

    return prediction_label
    
def extract_callsign(utterance, max_callsigns=2):
    """Return up to ``max_callsigns`` call signs tagged by the NER model.

    A call sign is a ``B-call`` word followed by any ``I-call`` words. The words
    are joined with spaces and upper-cased, e.g. ``['SWISS ONE TWO FOUR B']``.
    Spans are returned in utterance order, and the search stops once
    ``max_callsigns`` spans (2 by default) have been collected.
    """
    words = utterance.split()
    labels = evaluate_one_text(model, utterance)
    # Labels come one per word as seen by the BERT tokenizer, which may also split on
    # punctuation. If that yields more words than str.split(), only the aligned
    # prefix is used instead of raising IndexError.
    # NOTE(review): such utterances are misaligned anyway; consider normalising them.
    n_aligned = min(len(labels), len(words))

    callsigns = []
    i = 0
    while i < n_aligned and len(callsigns) < max_callsigns:
        if labels[i] != 'B-call':
            i += 1
            continue
        start = i
        i += 1
        while i < n_aligned and labels[i] == 'I-call':
            i += 1
        callsigns.append(' '.join(words[start:i]).upper())

    return callsigns

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

## Test

In [9]:
metadata_path = 'metadata.csv'
# Drop the first two columns by position (presumably index columns left by an
# earlier to_csv; TODO confirm), then shuffle all rows reproducibly.
metadata_df = (
    pd.read_csv(metadata_path)
    .iloc[:, 2:]
    .sample(frac=1, random_state=0)
    .reset_index(drop=True)
)

# The last 5% of the shuffled rows form the test set.
split1 = int(len(metadata_df) * 0.95)
df_test = metadata_df.iloc[split1:]
print(f"Size of the test set: {len(df_test)}")
df_test

Size of the test set: 1185


Unnamed: 0,transcript
22512,swiss one two four b turn left heading zero on...
22513,swissair nine three five nine is identified
22514,left heading three hundred austrian seven one ...
22515,praha easy four six two zero
22516,csa seven three two two line up runway one three
...,...
23692,k l m two six four rhein one two seven decimal...
23693,cross wayne at three thousand cleared i l s ap...
23694,ehm one six zero conti
23695,o l zu


In [11]:
for utterance in df_test['transcript']:
    print('UTTERANCE:', utterance)

    # Compute every extraction once and reuse it: extract_callsign runs a full
    # 512-token BERT forward pass, so repeating it multiplies the runtime.
    callsigns = extract_callsign(utterance)
    if len(callsigns) == 1:
        print('CALLSIGN:', callsigns[0])
    elif len(callsigns) > 1:
        print('CALLSIGN 1:', callsigns[0])
        print('CALLSIGN 2:', callsigns[1])

    heading = extract_heading_info(utterance)
    if heading is not None:
        print('HEADING INFO:', heading)

    communication = extract_communication_info(utterance)
    if communication is not None:
        print('COMMUNICATION INFO:', communication)

    flight_level = extract_fl_info(utterance)
    if flight_level is not None:
        print('FLIGHT LEVEL INFO:', flight_level)

    print('-' * 100)

UTTERANCE: swiss one two four b turn left heading zero one zero
CALLSIGN: SWISS ONE TWO FOUR B
HEADING INFO: TURN LEFT heading zero one zero
----------------------------------------------------------------------------------------------------
UTTERANCE: swissair nine three five nine is identified
CALLSIGN: SWISSAIR NINE THREE FIVE NINE
----------------------------------------------------------------------------------------------------
UTTERANCE: left heading three hundred austrian seven one one z


KeyboardInterrupt: 