# COVID-19 Classification Model
---

## 📦 Imports and Setup

- **tensorflow.keras.applications.ResNet50**: Pre-trained ResNet50 model for transfer learning.  
- **tensorflow.keras.layers**  
  - `Dense`: Fully connected layer for classification.  
  - `GlobalAveragePooling2D`: Converts feature maps into a single vector per image.  
  - `BatchNormalization`: Normalizes activations to improve training stability.  
- **tensorflow.keras.models.Model**: To define and build the custom model architecture.  
- **tensorflow.keras.callbacks** 
  - `EarlyStopping`: Stops training when performance no longer improves.  
  - `ModelCheckpoint`: Saves the best model during training.  
  - `CSVLogger`: Logs training progress into a CSV file.  
- **tensorflow.keras.optimizers.Adam**: Optimizer used for training the model.  
- **utils**  
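  - `load_data`: Loads the training and validation datasets.  
  - `visualize_samples`: Displays sample training images.  
  - `visualize_accuracy_loss`: Plots training accuracy and loss curves.  
- **Shared_vars**  
  - `DATA_DIR`: Path to the prepared dataset directory. The module is reloaded with `importlib` so that edits to it take effect without restarting the kernel.  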


In [None]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import Adam
from utils import load_data, visualize_samples, visualize_accuracy_loss

# Reload Shared_vars so edits to it are picked up without restarting the kernel.
import importlib
import Shared_vars
importlib.reload(Shared_vars)
from Shared_vars import DATA_DIR


-----

## 📂 Load and Explore Data

- `load_data`: Loads the training and validation datasets from the prepared directory.  
- `visualize_samples`: Displays a few sample images from the training set to confirm that data loading works correctly and to visually inspect the dataset.  

This step helps ensure that the data is correctly structured before model training begins.  


In [None]:
train_data, val_data = load_data(DATA_DIR, batch_size=32)

visualize_samples(train_data)


-----

## 🧠 Build Model Architecture

- **Base Model**:  
  - `ResNet50` pre-trained on ImageNet.  
  - `include_top=False`: Excludes the fully connected head, keeping only convolutional layers.  
  - `input_shape=(224, 224, 3)`: Input size for chest X-ray images.  
  - `base_model.trainable = True`: Unfreezes the base model for fine-tuning.  

- **Custom Layers**:  
  - `GlobalAveragePooling2D`: Reduces feature maps into a single vector per image.  
  - **First Dense Block**:  
    - Dense(512, ReLU) → BatchNormalization  
  - **Second Dense Block**:  
    - Dense(256, ReLU) → BatchNormalization  

- **Output Layer**:  
  - Dense(4, Softmax): Outputs probabilities for 4 classes (COVID, Normal, Lung Opacity, Viral Pneumonia).  

This architecture combines the power of transfer learning from ResNet50 with additional dense layers for classification.  
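
To get a feel for how much the custom head adds on top of ResNet50, the parameter counts of the layers listed above can be worked out by hand. The sketch below assumes the pooled feature vector has 2048 channels, which is what ResNet50 produces with `include_top=False`; the helper names `dense_params` and `batchnorm_params` are illustrative and not part of this notebook.

```python
# Parameter count of the custom head, assuming the 2048-channel output
# of ResNet50 with include_top=False. Helper names are illustrative only.

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def batchnorm_params(n_features):
    """gamma and beta (trainable) plus moving mean and variance (non-trainable)."""
    return 4 * n_features

FEATURES = 2048  # vector length after GlobalAveragePooling2D on ResNet50

head = {
    "dense_512": dense_params(FEATURES, 512),
    "bn_512": batchnorm_params(512),
    "dense_256": dense_params(512, 256),
    "bn_256": batchnorm_params(256),
    "output_4": dense_params(256, 4),
}
head_total = sum(head.values())

for name, count in head.items():
    print(f"{name:>10}: {count:,}")
print(f"{'total':>10}: {head_total:,}")
```

Most of the head's parameters sit in the first dense layer, because it maps the full 2048-dimensional feature vector down to 512 units. Running `model.summary()` after building the model is the authoritative way to check these numbers.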


In [None]:
base_model = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)
base_model.trainable = True  # Unfreeze the base model for fine-tuning

x = base_model.output
x = GlobalAveragePooling2D()(x)

# First dense block
x = Dense(512, activation='relu')(x)
x = BatchNormalization()(x)

# Second dense block
x = Dense(256, activation='relu')(x)
x = BatchNormalization()(x)

# Output layer
predictions = Dense(4, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)


-----

## ⚙️ Compile the Model

- **Optimizer**:  
  - `Adam` with a learning rate of `1e-5` for stable fine-tuning.  

- **Loss Function**:  
  - `categorical_crossentropy` since the task is multi-class classification. This loss expects one-hot encoded labels, so `load_data` must yield labels in that form.  

- **Metrics**:  
  - `accuracy` to monitor the model’s performance during training and validation.  

The model is now ready for training.  
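
As a rough illustration of what the loss measures, the following pure-Python sketch computes softmax probabilities and categorical cross-entropy for a single 4-class example. It is a simplified stand-in, not the Keras implementation; the logit values and the class order in `CLASSES` are made up for the example, since the actual label order depends on how `load_data` indexes the classes.

```python
import math

# Hypothetical class order; the real order depends on load_data.
CLASSES = ["COVID", "Normal", "Lung Opacity", "Viral Pneumonia"]

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Loss for one sample: -sum(y_true * log(y_pred)), with y_true one-hot."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

logits = [2.0, 0.5, 0.1, -1.0]  # made-up scores from the final layer
y_true = [1, 0, 0, 0]           # one-hot label for the first class

probs = softmax(logits)
loss = categorical_crossentropy(y_true, probs)

for name, p in zip(CLASSES, probs):
    print(f"{name:>16}: {p:.3f}")
print(f"loss: {loss:.3f}")
```

With a one-hot label, only the probability assigned to the true class contributes, so the loss reduces to the negative log of that probability. A model that predicts 0.25 for every class scores `log(4)`, a useful baseline when reading the first epochs of the training log.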


In [None]:
model.compile(
    optimizer=Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

---

## 🏋️‍♂️ Train the Model

- **Callbacks**:  
  - `EarlyStopping`: Monitors `val_loss`, stops training if no improvement after 5 epochs, and restores best weights.  
  - `ModelCheckpoint`: Saves the model to `/output/best_model.h5` whenever `val_accuracy` improves, so the file always holds the epoch with the highest validation accuracy.  
  - `CSVLogger`: Logs training progress into `training_log.csv`.  

- **Training Configuration**:  
  - Dataset: `train_data` with validation on `val_data`.  
  - Epochs: 50  
  - Batch size: 32 (set when calling `load_data`).  

- **Visualization**:  
  - `visualize_accuracy_loss`: Plots training and validation accuracy and loss curves over epochs.  

This step trains the model, limits overfitting by stopping once validation loss stops improving, and saves the best model automatically.  
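
The patience rule described above can be sketched in a few lines of plain Python. This is a simplified model of the behaviour, not the Keras source: `simulate_early_stopping` is a hypothetical helper and the validation losses are invented for illustration.

```python
def simulate_early_stopping(val_losses, patience=5):
    """Return (stop_epoch, best_epoch), both 1-indexed.

    Training stops once `patience` consecutive epochs fail to improve
    on the best validation loss seen so far.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch

# Invented validation losses: the best value occurs at epoch 3.
losses = [0.90, 0.70, 0.60, 0.62, 0.61, 0.63, 0.65, 0.64, 0.50]
stop_epoch, best_epoch = simulate_early_stopping(losses, patience=5)
print(f"stopped after epoch {stop_epoch}, best epoch was {best_epoch}")
```

In this run the loss at epoch 9 is never reached, because five epochs without improvement have already passed by epoch 8. With `restore_best_weights=True`, the in-memory model would return to the weights from the best epoch. Note that `EarlyStopping` tracks `val_loss` while `ModelCheckpoint` tracks `val_accuracy`, so the saved file and the restored weights can come from different epochs.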


In [None]:
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

checkpoint = ModelCheckpoint("/output/best_model.h5", save_best_only=True, monitor='val_accuracy')

csv_logger = CSVLogger('training_log.csv', append=False)

history = model.fit(
    train_data,
    validation_data=val_data,
    epochs=50,
    callbacks=[early_stop, checkpoint, csv_logger]
)

visualize_accuracy_loss()