In [4]:
import os
import numpy as np
import random
import re

import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM, Dense

In [5]:

def clean_text(text):
    # Remove URLs
    text = re.sub(r'http\S+', '', text)
    # Remove punctuation and convert to lowercase
    text = re.sub(r'[^\w\s]', '', text).lower()
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def load_data(data_dir, max_length=None):

    # Load all files in the data directory
    all_files = os.listdir(data_dir)

    # Filter only the files with the .bio extension
    bio_files = [f for f in all_files if f.endswith('.bio')]

    # Initialize lists to hold sentences and labels
    sentences = []
    labels = []

    # Loop through each file and read the sentences and labels
    for file in bio_files:
        with open(os.path.join(data_dir, file), 'r', encoding='utf-8') as f:
            current_sentences = []
            current_labels = []
            for line in f:
                if line.strip() == '':
                    # If we encounter a blank line, it means we've reached the end of a sentence
                    if len(current_sentences) > 0:
                        # Add the current sentence and labels to the list
                        sentences.append(current_sentences)
                        labels.append(current_labels)
                        # Reset the current sentence and labels lists
                        current_sentences = []
                        current_labels = []
                else:
                    # Otherwise, split the line into its word and label components
                    word, label = line.strip().split('\t')
                    current_sentences.append(clean_text(word))
                    current_labels.append(label)

    # Shuffle the sentences and labels
    combined = list(zip(sentences, labels))
    random.shuffle(combined)
    sentences[:], labels[:] = zip(*combined)

    # Split the data into training, validation, and test sets
    num_sentences = len(sentences)
    num_train = int(num_sentences * 0.8)
    num_valid = int(num_sentences * 0.1)

    train_sentences = sentences[:num_train]
    train_labels = labels[:num_train]
    valid_sentences = sentences[num_train:num_train+num_valid]
    valid_labels = labels[num_train:num_train+num_valid]
    test_sentences = sentences[num_train+num_valid:]
    test_labels = labels[num_train+num_valid:]

    # Convert the labels to one-hot encoding
    unique_labels = set(element for sublist in labels for element in sublist)
    label_to_index = {label: id+1 for id, label in enumerate(sorted(unique_labels))}
    index_to_label = {id: label for label, id in label_to_index.items()}

    # Add the new label and ID to the dictionaries
    label_to_index['<PAD>'] = 0
    index_to_label[0] = '<PAD>'

    num_classes = len(index_to_label) - 1

    train_labels = [[label_to_index[label] for label in labels] for labels in train_labels]
    train_labels = pad_sequences(train_labels, maxlen=max_length, padding='post', value=num_classes)
    train_labels = to_categorical(train_labels, num_classes=num_classes+1)

    valid_labels = [[label_to_index[label] for label in labels] for labels in valid_labels]
    valid_labels = pad_sequences(valid_labels, maxlen=max_length, padding='post', value=num_classes)
    valid_labels = to_categorical(valid_labels, num_classes=num_classes+1)

    test_labels = [[label_to_index[label] for label in labels] for labels in test_labels]
    test_labels = pad_sequences(test_labels, maxlen=max_length, padding='post', value=num_classes)
    test_labels = to_categorical(test_labels, num_classes=num_classes+1)

    return (train_sentences, train_labels), (valid_sentences, valid_labels), (test_sentences, test_labels), label_to_index, index_to_label

In [6]:
(train_sentences, train_labels), (val_sentences, val_labels), (test_sentences, test_labels), label2id, id2label = load_data('./data/BIO_FILES', 200)

In [7]:
train_labels

array([[[0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.]],

       [[0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.]],

       [[0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0.

In [8]:
VOCAB_SIZE = 100000
EMBEDDING_DIM = 128
MAX_LENGTH = 200
NUM_CLASSES = 35
LSTM_UNITS = 64
NUM_EPOCHS = 10

In [9]:
# Convert the input sentences to sequences of word indices
tokenizer = Tokenizer(num_words=VOCAB_SIZE)
tokenizer.fit_on_texts(train_sentences)

train_sequences = tokenizer.texts_to_sequences(train_sentences)
val_sequences = tokenizer.texts_to_sequences(val_sentences)
test_sequences = tokenizer.texts_to_sequences(test_sentences)

# Pad the sequences to a fixed length
train_sequences_padded = pad_sequences(train_sequences, maxlen=MAX_LENGTH, padding='post', truncating='post')
val_sequences_padded = pad_sequences(val_sequences, maxlen=MAX_LENGTH, padding='post', truncating='post')
test_sequences_padded = pad_sequences(test_sequences, maxlen=MAX_LENGTH, padding='post', truncating='post')

In [10]:
# Define the model architecture
model = tf.keras.models.Sequential([
    Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, input_length=MAX_LENGTH),
    Bidirectional(LSTM(units=LSTM_UNITS, return_sequences=True)),
    Dense(NUM_CLASSES, activation='softmax')
])

#
# model = tf.keras.models.Sequential([
#     Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, input_length=MAX_LENGTH),
#     Bidirectional(LSTM(units=LSTM_UNITS, return_sequences=True)),
#     Dense(64, activation='relu'),
#     Dense(NUM_CLASSES, activation='softmax')
# ])


In [11]:
# Compile the model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])


In [12]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       (None, 200, 128)          12800000  
                                                                 
 bidirectional (Bidirectiona  (None, 200, 128)         98816     
 l)                                                              
                                                                 
 dense (Dense)               (None, 200, 35)           4515      
                                                                 
Total params: 12,903,331
Trainable params: 12,903,331
Non-trainable params: 0
_________________________________________________________________


In [13]:
# Train the model
model.fit(train_sequences_padded, train_labels, epochs=NUM_EPOCHS, validation_data=(val_sequences_padded, val_labels))

# Evaluate the model
test_loss, test_acc = model.evaluate(test_sequences_padded, test_labels)

# Print the test accuracy
print('Test accuracy:', test_acc)

Epoch 1/10


2023-04-03 10:09:02.430344: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Test accuracy: 0.9690219759941101


In [14]:
test_labels

array([[[0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.]],

       [[0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0.

In [17]:
import re
import string
def tokenize_text(text):
    # Tokenize the text into a list of words
    tokens = []
    for sentence in text.split('\n'):
        for word in sentence.split():
            # Remove trailing punctuation marks from the word
            while word and word[-1] in string.punctuation:
                word = word[:-1]
            tokens.append(word)
    return tokens

text = "The patient is a 55-year-old male with a history of hypertension and diabetes. He presented to the emergency department with complaints of chest pain, shortness of breath, and dizziness. The patient's blood pressure was 180/110 mmHg and his heart rate was 110 beats per minute. A 12-lead electrocardiogram showed ST-segment elevation in the anterior leads. The patient was diagnosed with an acute myocardial infarction and was immediately started on heparin and aspirin therapy. He underwent a cardiac catheterization and was found to have significant stenosis in the left anterior descending artery. He underwent percutaneous coronary intervention with stent placement and his symptoms improved. He was discharged home on aspirin, clopidogrel, atorvastatin, and lisinopril."


# tokens = re.findall(r'\b\w+\b', text)
tokens = tokenize_text(text)

sequence = tokenizer.texts_to_sequences([' '.join(clean_text(token) for token in tokens)])
padded_sequence = pad_sequences(sequence, maxlen=MAX_LENGTH, padding='post')

# Make the prediction
prediction = model.predict(np.array(padded_sequence))

# Decode the prediction
predicted_labels = np.argmax(prediction, axis=-1)
predicted_labels = [id2label[i] for i in predicted_labels[0]]

# Print the predicted named entities
print("Predicted Named Entities:")
for i in range(len(tokens)):
    print(f"{tokens[i]}: {''.join(predicted_labels[i])}")


Predicted Named Entities:
The: O
patient: O
is: O
a: O
55-year-old: B-Therapeutic_procedure
male: O
with: O
a: I-History
history: I-History
of: I-History
hypertension: I-History
and: I-History
diabetes: I-History
He: B-Clinical_event
presented: O
to: O
the: O
emergency: O
department: O
with: O
complaints: O
of: B-Biological_structure
chest: B-Sign_symptom
pain: B-Sign_symptom
shortness: I-Sign_symptom
of: I-Sign_symptom
breath: O
and: B-Sign_symptom
dizziness: O
The: O
patient's: B-Diagnostic_procedure
blood: I-Diagnostic_procedure
pressure: O
was: I-Lab_value
180/110: O
mmHg: O
and: B-Diagnostic_procedure
his: I-Diagnostic_procedure
heart: O
rate: B-Lab_value
was: I-Lab_value
110: I-Lab_value
beats: I-Lab_value
per: O
minute: B-Diagnostic_procedure
A: O
12-lead: B-Sign_symptom
electrocardiogram: I-Lab_value
showed: O
ST-segment: O
elevation: B-Biological_structure
in: O
the: O
anterior: O
leads: O
The: O
patient: O
was: O
diagnosed: O
with: B-Disease_disorder
an: I-Disease_disorder
ac

In [16]:
predicted_labels

['O',
 'O',
 'O',
 'O',
 'B-Therapeutic_procedure',
 'O',
 'O',
 'I-History',
 'I-History',
 'I-History',
 'I-History',
 'I-History',
 'I-History',
 'B-Clinical_event',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-Biological_structure',
 'B-Sign_symptom',
 'B-Sign_symptom',
 'I-Sign_symptom',
 'I-Sign_symptom',
 'O',
 'B-Sign_symptom',
 'O',
 'O',
 'B-Diagnostic_procedure',
 'I-Diagnostic_procedure',
 'O',
 'I-Lab_value',
 'O',
 'O',
 'B-Diagnostic_procedure',
 'I-Diagnostic_procedure',
 'O',
 'B-Lab_value',
 'I-Lab_value',
 'I-Lab_value',
 'I-Lab_value',
 'O',
 'B-Diagnostic_procedure',
 'O',
 'B-Sign_symptom',
 'I-Lab_value',
 'O',
 'O',
 'B-Biological_structure',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-Disease_disorder',
 'I-Disease_disorder',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-Medication',
 'O',
 'B-Medication',
 'O',
 'O',
 'O',
 'O',
 'B-Disease_disorder',
 'I-Diagnostic_procedure',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B-Sign_symptom',
 'O',
 'O',
 'B-Biological_stru

In [24]:
np.array(padded_sequence)

array([[   1,   11,  145,    5,  280,    6,    5,   34,    3,  218,    2,
         467,   19,   65,    8,    1,  224,  189,    6, 1500,    3,   50,
          74,  607,    3,  410,    2, 2032,    1,   79,   32,   63,    4,
         143,    2,   28,   78,  150,    4, 3456, 1232,  215, 1080,    5,
        4605,  594,   13, 2145,  758,    7,    1,  262,  582,    1,   11,
           4,   88,    6,   20,  200,  538, 1330,    2,    4,  546,  199,
          12, 1114,    2, 1355,   77,   19,  108,    5,  111,  910,    2,
           4,  137,    8,  216,  164,  529,    7,    1,   23,  262,  907,
          83,   19,  108,  864,  174, 1010,    6,  542, 1332,    2,   28,
         120,  323,   19,    4,  156,  434,   12, 1355,    2, 4116,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0, 

In [25]:
np.array(padded_sequence)

array([[   1,   11,  145,    5,  280,    6,    5,   34,    3,  218,    2,
         467,   19,   65,    8,    1,  224,  189,    6, 1500,    3,   50,
          74,  607,    3,  410,    2, 2032,    1,   79,   32,   63,    4,
         143,    2,   28,   78,  150,    4, 3456, 1232,  215, 1080,    5,
        4605,  594,   13, 2145,  758,    7,    1,  262,  582,    1,   11,
           4,   88,    6,   20,  200,  538, 1330,    2,    4,  546,  199,
          12, 1114,    2, 1355,   77,   19,  108,    5,  111,  910,    2,
           4,  137,    8,  216,  164,  529,    7,    1,   23,  262,  907,
          83,   19,  108,  864,  174, 1010,    6,  542, 1332,    2,   28,
         120,  323,   19,    4,  156,  434,   12, 1355,    2, 4116,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0, 

In [15]:
import re
import nltk
nltk.download('punkt')

[nltk_data] Error loading punkt: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1131)>


False

In [16]:
# Split text into sentences
text = "The patient is a 55-year-old male with a history of hypertension and diabetes. He presented to the emergency department with complaints of chest pain, shortness of breath, and dizziness. The patient's blood pressure was 180/110 mmHg and his heart rate was 110 beats per minute. A 12-lead electrocardiogram showed ST-segment elevation in the anterior leads. The patient was diagnosed with an acute myocardial infarction and was immediately started on heparin and aspirin therapy. He underwent a cardiac catheterization and was found to have significant stenosis in the left anterior descending artery. He underwent percutaneous coronary intervention with stent placement and his symptoms improved. He was discharged home on aspirin, clopidogrel, atorvastatin, and lisinopril."

sentences = nltk.sent_tokenize(text)

# Process each sentence
for sentence in sentences:
    tokens = re.findall(r'\b\w+\b', sentence)
    sequence = tokenizer.texts_to_sequences([' '.join(clean_text(token) for token in tokens)])
    padded_sequence = pad_sequences(sequence, maxlen=MAX_LENGTH, padding='post')

    # Make the prediction
    prediction = model.predict(np.array(padded_sequence))

    # Decode the prediction
    predicted_labels = np.argmax(prediction, axis=-1)
    predicted_labels = [id2label[i] for i in predicted_labels[0]]

    # Print the predicted named entities
    print("Predicted Named Entities:")
    for i in range(len(tokens)):
        print(f"{tokens[i]}: {''.join(predicted_labels[i])}")


Predicted Named Entities:
The: O
patient: O
is: O
a: O
55: O
year: O
old: O
male: O
with: O
a: O
history: O
of: O
hypertension: O
and: O
diabetes: O
Predicted Named Entities:
He: O
presented: O
to: O
the: O
emergency: O
department: O
with: O
complaints: O
of: O
chest: O
pain: O
shortness: O
of: O
breath: O
and: O
dizziness: O
Predicted Named Entities:
The: O
patient: O
s: O
blood: O
pressure: O
was: O
180: O
110: O
mmHg: O
and: O
his: O
heart: O
rate: O
was: O
110: O
beats: O
per: O
minute: O
Predicted Named Entities:
A: O
12: O
lead: O
electrocardiogram: O
showed: O
ST: O
segment: O
elevation: O
in: O
the: O
anterior: O
leads: O
Predicted Named Entities:
The: O
patient: O
was: O
diagnosed: O
with: O
an: O
acute: O
myocardial: O
infarction: O
and: O
was: O
immediately: O
started: O
on: O
heparin: O
and: O
aspirin: O
therapy: O
Predicted Named Entities:
He: O
underwent: O
a: O
cardiac: O
catheterization: O
and: O
was: O
found: O
to: O
have: O
significant: O
stenosis: O
in: O
the: O
left

In [161]:
text = "The patient is a 55-year-old male with a history of hypertension and diabetes. He presented to the emergency department with complaints of chest pain, shortness of breath, and dizziness. The patient's blood pressure was 180/110 mmHg and his heart rate was 110 beats per minute. A 12-lead electrocardiogram showed ST-segment elevation in the anterior leads. The patient was diagnosed with an acute myocardial infarction and was immediately started on heparin and aspirin therapy. He underwent a cardiac catheterization and was found to have significant stenosis in the left anterior descending artery. He underwent percutaneous coronary intervention with stent placement and his symptoms improved. He was discharged home on aspirin, clopidogrel, atorvastatin, and lisinopril."

sequence = tokenizer.texts_to_sequences([text])
padded_sequence = pad_sequences(sequence, maxlen=MAX_LENGTH, padding='post')

# Make the prediction
prediction = model.predict(np.array(padded_sequence))

# Decode the prediction
predicted_labels = np.argmax(prediction, axis=-1)
predicted_labels = [id2label[i] for i in predicted_labels[0]]

# Print the predicted named entities
print("Predicted Named Entities:")
for i in range(len(text.split())):
    print(f"{text.split()[i]}: {''.join(predicted_labels[i])}")


Predicted Named Entities:
The: O
patient: O
is: O
a: O
55-year-old: O
male: O
with: B-Clinical_event
a: B-Clinical_event
history: O
of: O
hypertension: O
and: I-History
diabetes.: I-History
He: I-History
presented: I-History
to: I-History
the: B-Clinical_event
emergency: O
department: O
with: O
complaints: O
of: O
chest: O
pain,: O
shortness: B-Biological_structure
of: B-Sign_symptom
breath,: B-Sign_symptom
and: O
dizziness.: I-Sign_symptom
The: O
patient's: B-Sign_symptom
blood: O
pressure: O
was: B-Diagnostic_procedure
180/110: I-Diagnostic_procedure
mmHg: O
and: B-Lab_value
his: B-Lab_value
heart: B-Lab_value
rate: O
was: O
110: B-Diagnostic_procedure
beats: I-Diagnostic_procedure
per: O
minute.: B-Lab_value
A: I-Lab_value
12-lead: I-Lab_value
electrocardiogram: I-Lab_value
showed: O
ST-segment: O
elevation: B-Diagnostic_procedure
in: B-Diagnostic_procedure
the: O
anterior: O
leads.: B-Biological_structure
The: I-Sign_symptom
patient: O
was: O
diagnosed: B-Biological_structure
with:

In [172]:
import tensorflow as tf
from tensorflow_addons.text import crf_log_likelihood, crf_decode

class CRF(tf.keras.Model):
    def __init__(self, num_labels):
        super(CRF, self).__init__()
        self.num_labels = num_labels
        self.transition_params = tf.Variable(tf.random.normal(shape=(num_labels, num_labels)))

    @tf.function
    def call(self, inputs):
        input_features, label_ids, input_mask = inputs
        sequence_lengths = tf.reduce_sum(input_mask, axis=1)
        logits = tf.keras.layers.Dense(self.num_labels)(input_features)
        log_likelihood, self.transition_params = crf_log_likelihood(logits, label_ids, sequence_lengths,
                                                                    transition_params=self.transition_params)
        loss = -tf.reduce_mean(log_likelihood)
        pred_ids, _ = crf_decode(logits, self.transition_params, sequence_lengths)
        return loss, pred_ids


In [173]:
NUM_CLASSES

34

In [174]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.text import crf_log_likelihood, crf_decode

# Define the CRF model architecture
model = Sequential([
    layers.Input(shape=(MAX_LENGTH,), dtype=tf.int32),
    layers.Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, input_length=MAX_LENGTH),
    layers.Bidirectional(layers.LSTM(units=64, return_sequences=True)),
    layers.Dense(NUM_CLASSES),
    CRF(NUM_CLASSES)
])

# Compile the model
optimizer = Adam(lr=0.001)
model.compile(optimizer=optimizer, loss=crf_log_likelihood, metrics=[crf_decode])

OperatorNotAllowedInGraphError: Exception encountered when calling layer "crf_2" (type CRF).

in user code:

    File "/var/folders/1k/l9m7dlqd1knbl3m543_55s4h0000gn/T/ipykernel_8978/4124767977.py", line 12, in call  *
        input_features, label_ids, input_mask = inputs

    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.


Call arguments received by layer "crf_2" (type CRF):
  • inputs=tf.Tensor(shape=(None, 128, 34), dtype=float32)

In [176]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.text import crf_log_likelihood, crf_decode

class CRF(tf.keras.layers.Layer):
    def __init__(self, num_labels):
        super(CRF, self).__init__()
        self.num_labels = num_labels
        self.transition_params = self.add_weight(shape=(num_labels, num_labels), initializer='glorot_uniform')

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, None), dtype=tf.float32)])
    def call(self, inputs):
        logits = inputs
        path, _ = crf_decode(logits, self.transition_params, tf.reduce_sum(tf.cast(inputs != 0, dtype=tf.int32), axis=-1))
        return path


# Define the CRF model architecture
model = Sequential([
    layers.Input(shape=(MAX_LENGTH,), dtype=tf.int32),
    layers.Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, input_length=MAX_LENGTH),
    layers.Bidirectional(layers.LSTM(units=64, return_sequences=True)),
    layers.Dense(NUM_CLASSES),
    CRF(NUM_CLASSES)
])

# Compile the model
optimizer = Adam(lr=0.001)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss, _ = model(x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


ValueError: Exception encountered when calling layer "crf_4" (type CRF).

in user code:

    File "/var/folders/1k/l9m7dlqd1knbl3m543_55s4h0000gn/T/ipykernel_8978/577521863.py", line 16, in call  *
        path, _ = crf_decode(logits, self.transition_params, tf.reduce_sum(tf.cast(inputs != 0, dtype=tf.int32), axis=-1))
    File "/Users/ajaykarthicksenthilkumar/dev/projects/NER-medical-text/lib/python3.8/site-packages/tensorflow_addons/text/crf.py", line 570, in _multi_seq_fn  *
        backpointers = tf.reverse_sequence(

    ValueError: Shape must be rank 1 but is rank 2 for '{{node cond/ReverseSequence}} = ReverseSequence[T=DT_INT32, Tlen=DT_INT32, batch_dim=0, seq_dim=1](cond/rnn/transpose_2, cond/Maximum)' with input shapes: [?,?,34], [?,?].


Call arguments received by layer "crf_4" (type CRF):
  • inputs=tf.Tensor(shape=(None, 128, 34), dtype=float32)

In [None]:

# Train the model
for epoch in range(10):
    epoch_loss = 0
    for x, y in train_dataset:
        loss = train_step(x, y)
        epoch_loss += loss
    print(f"Epoch {epoch+1} loss: {epoch_loss}")
