In [None]:
import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from peft import LoraConfig, get_peft_model
from PIL import Image
from shapely.geometry import LineString
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AdamW, AutoModelForCausalLM, AutoProcessor, get_scheduler

project_dir = os.getcwd()

# Environment variables.
# setdefault() keeps any value already exported in the shell, so a real token (or a GPU
# selection made outside the notebook) is never overwritten by the placeholders below.
# NOTE(review): huggingface_hub itself reads HF_TOKEN; confirm which variable your setup uses.
os.environ.setdefault("HUGGINGFACE_API_KEY", "Put your huggingface api key here")
# Put your gpu id here; it only takes effect if set before CUDA is first initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 4
NUM_WORKERS = 0  # collate_fn moves tensors to DEVICE, so keep batching in the main process
CHECKPOINT = "microsoft/Florence-2-base-ft"  # you can try a larger model if you have enough GPU memory
REVISION = 'refs/pr/6'
start_epoch = '0'  # '0' = train from scratch; 'N' = resume from {ckpt_dir}/epoch_N
ckpt_dir = f'{project_dir}/tulane_model_checkpoints'
output_dir = f'{project_dir}/output'



In [None]:
# Setup LoRA and florence-2 model
from peft import PeftModel, PeftConfig

EPOCHS = 50
LR = 4e-6

# Setup LoRA
config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    bias="none",
    inference_mode=False,
    use_rslora=True,
    init_lora_weights="gaussian",
)


if start_epoch != '0':
    model_id = ckpt_dir + '/epoch_' + start_epoch  # Continue to train from a saved epoch
else:
    model_id = CHECKPOINT  # Train from scratch

# Always load the *base* weights from CHECKPOINT. A resume directory is written by
# train_model via peft_model.save_pretrained(), which stores only the LoRA adapter
# (plus the processor), not a full model, so it must be attached with PeftModel below
# rather than loaded as the base model.
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT, trust_remote_code=True).to(DEVICE)

if start_epoch == '0':
    peft_model = get_peft_model(model, config)  # fresh LoRA adapters
else:
    # Continue to train: re-attach the saved adapter and keep it trainable.
    peft_model = PeftModel.from_pretrained(model, model_id, is_trainable=True)

peft_model.print_trainable_parameters()
# The processor is saved next to every checkpoint, so model_id works in both modes.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)


In [None]:
# Setup dataset and dataloaders

class LaneDetectionDataset(Dataset):
    """Lane-detection samples for Florence-2 fine-tuning and inference.

    In normal mode, entries are read from a JSON-Lines file (one JSON object per line,
    each with at least ``image``, ``prefix`` and ``suffix`` keys). ``image`` is a filename
    relative to ``image_directory_path``. Blank lines are ignored.

    In test mode the JSONL file is not read. Every ``.png``/``.jpg``/``.jpeg`` file
    (case-insensitive) in ``image_directory_path`` becomes an entry with the
    ``<OD_LANE>`` prompt and an empty answer, in sorted filename order.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, jsonl_file_path: Optional[str], image_directory_path: str, test_mode: bool = False):
        """
        Args:
            jsonl_file_path: path to the JSONL annotations; may be None in test mode.
            image_directory_path: directory containing the images.
            test_mode: if True, enumerate images from the directory instead of reading annotations.
        """
        self.jsonl_file_path = jsonl_file_path
        self.image_directory_path = image_directory_path
        self.test_mode = test_mode
        self.entries = self._load_entries()

    def _load_entries(self) -> List[Dict[str, Any]]:
        """Return the list of entry dicts for this dataset (see class docstring)."""
        if self.test_mode:
            # sorted(): os.listdir order is arbitrary, so sort to make image indices reproducible.
            # lower(): also accept upper-case extensions such as "IMG_01.JPG".
            filenames = sorted(
                f for f in os.listdir(self.image_directory_path)
                if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            return [{"image": f, "prefix": "<OD_LANE>", "suffix": ""} for f in filenames]

        entries = []
        with open(self.jsonl_file_path, 'r', encoding='utf-8') as file:
            for line in file:
                if not line.strip():  # tolerate blank lines, e.g. a trailing newline at EOF
                    continue
                entries.append(json.loads(line))
        return entries

    def __len__(self) -> int:
        return len(self.entries)

    def __getitem__(self, idx: int) -> "Tuple[Image.Image, Optional[Dict[str, Any]]]":
        """Return ``(image, data)`` for entry ``idx``.

        ``data`` is ``{'prefix', 'suffix'}``, or None in test mode. The image is always RGB.
        """
        entry = self.entries[idx]
        image_path = os.path.join(self.image_directory_path, entry['image'])
        # Image.open is lazy and keeps the file handle open until the pixels are read.
        # Decode inside `with` so the handle is released right away. Convert to RGB so
        # grayscale / RGBA / palette images all reach the processor with 3 channels.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.test_mode:
            return image, None

        return image, {'prefix': entry['prefix'], 'suffix': entry['suffix']}

def collate_fn(batch):
    """Turn a list of ``(image, data)`` samples into one model batch.

    If the first sample carries no data (test mode), the images are returned as a
    plain tuple. Otherwise the prompts and images are encoded together by the global
    ``processor`` (padded to a common length) and moved to ``DEVICE``. The result is
    returned alongside the list of target answer strings.
    """
    images, targets = zip(*batch)
    if targets[0] is None:  # Test mode
        return images

    prompts = []
    expected = []
    for target in targets:
        prompts.append(target['prefix'])
        expected.append(target['suffix'])

    encoded = processor(text=prompts, images=images, return_tensors="pt", padding=True)
    return encoded.to(DEVICE), expected

# Initialize datasets and dataloaders (all splits live under one data root)
_data_root = f"{project_dir}/florence2lane_data/tulane_florence"

train_dataset = LaneDetectionDataset(
    jsonl_file_path=f"{_data_root}/train/annotations.json",
    image_directory_path=f"{_data_root}/train/images",
    test_mode=False,
)
val_dataset = LaneDetectionDataset(
    jsonl_file_path=f"{_data_root}/valid/annotations.json",
    image_directory_path=f"{_data_root}/valid/images",
    test_mode=False,
)
test_dataset = LaneDetectionDataset(
    jsonl_file_path=None,
    image_directory_path=f"{_data_root}/test/images",  # put your test images here
    test_mode=True,
)

# Shared loader settings; only the training split is shuffled.
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, **_loader_kwargs)
test_loader = DataLoader(test_dataset, **_loader_kwargs)


In [None]:

def parse_lanes(lane_string, imgHgt, imgWdt):
    """
    Parse Florence-2 lane output into polylines in pixel coordinates.

    The string is split on the literal token ``lane``. In each chunk, the numeric value
    of every ``<loc_N>`` token is taken in (x, y) pairs. Each value is divided by 1000,
    clipped to [0, 1] and scaled by ``imgWdt - 1`` (x) or ``imgHgt - 1`` (y). A trailing
    unpaired coordinate is dropped.

    Args:
        lane_string: decoded model output, e.g. ``"lane<loc_10><loc_20><loc_30><loc_40>lane..."``.
        imgHgt: image height in pixels.
        imgWdt: image width in pixels.

    Returns:
        dict with parallel lists ``'labels'`` (always ``'lane'``) and ``'lanes'``
        (shapely ``LineString``). Chunks that yield fewer than two points are skipped.
        Shapely cannot build a one-point LineString, and an empty one cannot be plotted.
        Such chunks are typical of untrained or truncated generations.
    """
    parsed_lane = {'labels': [], 'lanes': []}
    if not lane_string:
        return parsed_lane

    x_scale = imgWdt - 1
    y_scale = imgHgt - 1

    for chunk in lane_string.split('lane'):
        coords = [float(c) for c in re.findall(r'<loc_(\d+)>', chunk)]
        if len(coords) % 2 != 0:
            coords = coords[:-1]
        if len(coords) < 4:  # fewer than two (x, y) points -> not a drawable polyline
            continue

        points = [
            (np.clip(coords[k] / 1000.0, 0.0, 1.0) * x_scale,
             np.clip(coords[k + 1] / 1000.0, 0.0, 1.0) * y_scale)
            for k in range(0, len(coords), 2)
        ]
        parsed_lane['labels'].append('lane')
        parsed_lane['lanes'].append(LineString(points))

    return parsed_lane

def plot_polylines_on_image(parsed_lanes, img, img_idx: int, save_img: bool = False, epoch: int = 0):
    """
    Draw parsed lane polylines over an image, display it, and optionally save it as PNG.

    Each lane is drawn in a tab10 colour with legend entry ``"lane {i}"``. Its first
    vertex is marked with a red dot and its last vertex with a blue dot.

    Args:
        parsed_lanes: dict as returned by ``parse_lanes``. ``'lanes'`` holds LineStrings
            and ``'labels'`` runs parallel to it.
        img: image passed to ``ax.imshow`` (e.g. a PIL image).
        img_idx: index used in the saved filename.
        save_img: if True, write ``image_{img_idx}_epoch_{epoch}.png`` under
            ``{output_dir}/inference_results/epoch_{epoch}``.
        epoch: epoch number used in the save directory and filename.
    """
    fig, ax = plt.subplots()
    ax.imshow(img)
    # pyplot.get_cmap: matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in matplotlib 3.9.
    cmap = plt.get_cmap('tab10')

    for i, line in enumerate(parsed_lanes['lanes']):
        if line.is_empty:
            continue  # nothing to draw, and x[0] below would raise on an empty line
        x, y = line.xy
        # Cycle through the palette; integer indices >= cmap.N would all map to one colour.
        color = cmap(i % cmap.N)
        ax.plot(x[0], y[0], 'o', color='red')
        ax.plot(x[-1], y[-1], 'o', color='blue')
        ax.plot(x, y, label=f"lane {i}", color=color, alpha=0.8)

    handles, labels = ax.get_legend_handles_labels()
    if handles:  # skip an empty legend box when no lanes were drawn
        ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.)

    fig.tight_layout()

    if save_img:
        save_dir = f"{output_dir}/inference_results/epoch_{epoch}"
        os.makedirs(save_dir, exist_ok=True)
        img_name = f"image_{img_idx}_epoch_{epoch}.png"
        fig.savefig(f"{save_dir}/{img_name}", bbox_inches='tight')

    plt.show()
    plt.close(fig)  # free the figure; this is called once per image in loops

def florence2_inference_results(model, dataset: "LaneDetectionDataset", count: int, epoch: int, save_img: bool = False):
    """
    Run lane detection on the first ``count`` samples of ``dataset`` and plot each result.

    For each sample, the prompt (the sample's ``prefix``, or ``<OD_LANE>`` in test mode)
    and image are encoded with the global ``processor``. Lanes are generated with
    3-beam greedy search (up to 1024 new tokens), decoded, parsed with ``parse_lanes``,
    and drawn with ``plot_polylines_on_image``.

    The model is switched to eval mode, and generation runs under ``torch.no_grad()``.
    The model's previous train/eval mode is restored afterwards, even on error.

    Args:
        model: Florence-2 (optionally PEFT-wrapped) model.
        dataset: ``LaneDetectionDataset`` (or any sequence of ``(image, data)`` pairs).
        count: number of leading samples to run; clipped to ``len(dataset)``.
        epoch: forwarded to ``plot_polylines_on_image`` (used in saved filenames).
        save_img: forwarded to ``plot_polylines_on_image``.

    Returns:
        The ``parse_lanes`` dict for the last processed sample, or None when no sample
        was processed (``count <= 0`` or an empty dataset).
    """
    count = min(count, len(dataset))
    lane_result = None
    if count <= 0:
        return lane_result  # previously raised UnboundLocalError here

    was_training = model.training
    model.eval()  # disable dropout during generation
    try:
        with torch.no_grad():  # no autograd graph is needed for inference
            for i in range(count):
                image, data = dataset[i]
                prefix = "<OD_LANE>" if data is None else data['prefix']  # None => test mode

                imgWdt, imgHgt = image.size
                inputs = processor(text=prefix, images=image, return_tensors="pt").to(DEVICE)

                generated_ids = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=1024,
                    early_stopping=False,
                    do_sample=False,
                    num_beams=3,
                )
                # Keep special tokens: the <loc_N> tokens carry the coordinates.
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
                parsed_answer = processor.post_process_generation(
                    generated_text,
                    task='<OD_LANE>',
                    image_size=(imgWdt, imgHgt))

                lane_result = parse_lanes(parsed_answer['<OD_LANE>'], imgHgt, imgWdt)
                plot_polylines_on_image(lane_result, image, img_idx=i, save_img=save_img, epoch=epoch)
    finally:
        model.train(was_training)

    return lane_result

def _answers_to_labels(answers, processor):
    """Tokenize target strings into a padded label tensor on DEVICE, with padding masked out.

    Padded positions are set to -100, the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``. The loss is then computed only on real answer
    tokens and not on padding.
    NOTE(review): relies on the Florence-2 head using the default ignore_index and
    mapping -100 back to the pad id when shifting labels into decoder inputs (as
    BART-style seq2seq heads do). Confirm for the pinned model revision.
    """
    labels = processor.tokenizer(
        text=answers,
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False,
    ).input_ids
    pad_id = processor.tokenizer.pad_token_id
    if pad_id is not None:
        labels[labels == pad_id] = -100
    return labels.to(DEVICE)


def train_model(train_loader, val_loader, model, processor, epochs=EPOCHS, start_epoch=start_epoch, model_id=model_id, lr=LR):
    """
    Fine-tune ``model``, then validate, preview and checkpoint it after every epoch.

    Args:
        train_loader: DataLoader yielding ``(inputs, answers)`` as produced by ``collate_fn``.
        val_loader: same format; its ``.dataset`` is also used for a 2-image preview.
        model: the (PEFT-wrapped) Florence-2 model; updated in place.
        processor: Florence-2 processor. Its tokenizer encodes the targets, and it is
            saved next to each checkpoint.
        epochs: total number of epochs of the full run (not the number remaining).
        start_epoch: epochs already completed (int or numeric string, since the
            notebook-level default is the string ``'0'``). Runs epochs
            ``start_epoch + 1 .. epochs``.
        model_id: unused; kept so existing calls keep working.
        lr: peak learning rate.

    Side effects:
        Prints average train/validation losses and plots 2 validation predictions
        per epoch. Writes ``{ckpt_dir}/epoch_{n}`` via ``save_pretrained`` for both
        model and processor.
    """
    start_epoch = int(start_epoch)
    steps_per_epoch = len(train_loader)
    total_steps = epochs * steps_per_epoch
    completed_steps = start_epoch * steps_per_epoch

    # torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
    # weight_decay are set to that class's defaults so the update rule is unchanged.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)

    # Linear decay to 0 over the whole run, using the same formula as
    # get_scheduler("linear") with no warm-up. The step is offset by the epochs already
    # completed, so a resumed run continues the original schedule instead of
    # restarting from the full learning rate.
    def _linear_decay(step):
        return max(0.0, (total_steps - (completed_steps + step)) / max(1, total_steps))

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _linear_decay)

    for epoch in range(start_epoch, epochs):
        epoch_num = epoch + 1  # 1-based epoch number used for logging and checkpoint names

        model.train()
        train_loss = 0.0
        for inputs, answers in tqdm(train_loader, desc=f"Training Epoch {epoch_num}/{epochs}"):
            labels = _answers_to_labels(answers, processor)
            outputs = model(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            train_loss += loss.item()

        avg_train_loss = train_loss / max(1, len(train_loader))
        print(f"Average Training Loss: {avg_train_loss}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, answers in tqdm(val_loader, desc=f"Validation Epoch {epoch_num}/{epochs}"):
                labels = _answers_to_labels(answers, processor)
                outputs = model(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], labels=labels)
                val_loss += outputs.loss.item()

        avg_val_loss = val_loss / max(1, len(val_loader))
        print(f"Average Validation Loss: {avg_val_loss}")

        # Use the same epoch number as the checkpoint saved below.
        florence2_inference_results(model, val_loader.dataset, 2, epoch_num, save_img=False)

        ckpt_save_dir = f"{ckpt_dir}/epoch_{epoch_num}"
        os.makedirs(ckpt_save_dir, exist_ok=True)
        model.save_pretrained(ckpt_save_dir)
        processor.save_pretrained(ckpt_save_dir)

In [None]:
# Sanity check before training: when starting from scratch, the model has not yet learned what <OD_LANE> means
florence2_inference_results(peft_model, val_dataset, count=2, epoch=0, save_img=False)

In [None]:
# Train (or resume training) the LoRA adapters; a checkpoint is written after every epoch
train_model(
    train_loader,
    val_loader,
    peft_model,
    processor,
    epochs=EPOCHS,
    start_epoch=int(start_epoch),
    model_id=model_id,
    lr=LR,
)

In [None]:
# Post-training check on up to 10 validation images; the figures are also saved to disk
florence2_inference_results(peft_model, val_dataset, count=10, epoch=EPOCHS, save_img=True)