In [16]:
import sqlite3
import pandas as pd
import numpy as np 
import torchvision
import torch
from PIL import Image
import asyncio
import io
import os
import json
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from torch.utils.data import Subset
from deepmorpho.dl_folder.data_classes.morpho_dataset import EmbryoDataset
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from torch.utils.data import DataLoader
import timm 

In [4]:
def connect_to_database(db_path):
    """Open a connection to the SQLite database at ``db_path``.

    A status message is printed in both the success and the failure case.

    Returns:
        The ``sqlite3.Connection`` on success, or ``None`` when
        ``sqlite3.connect`` raises ``sqlite3.Error``.
    """
    try:
        connection = sqlite3.connect(db_path)
    except sqlite3.Error as err:
        print(f"Ошибка при подключении к SQLite: {err}")
        return None
    else:
        print("Успешно подключились к базе данных")
        return connection

In [6]:
def extract_data(conn):
    """Load frames and their parsed annotations from ``well_timeline_frames``.

    Each row supplies an id (``wtf_wtl_id``), an image blob (``wtf_frame``)
    and a JSON annotation string (``wtf_rcnn_text``). A row is kept only when
    its annotation parses as JSON *and* its blob opens as an image, so the two
    returned lists always have the same length and the same order.

    Args:
        conn: An open ``sqlite3`` connection.

    Returns:
        ``(images, annotations)`` where ``images`` is a list of
        ``(wtf_wtl_id, PIL image)`` tuples and ``annotations`` is a list of
        ``(wtf_wtl_id, parsed_json)`` tuples. If the query itself fails, the
        error is printed and two empty lists are returned, so callers can
        always unpack the result.
    """
    images = []
    annotations = []

    try:
        cursor = conn.cursor()
        try:
            cursor.execute(
                "SELECT wtf_wtl_id, wtf_frame, wtf_rcnn_text FROM well_timeline_frames;"
            )
            rows = cursor.fetchall()
        finally:
            cursor.close()
    except sqlite3.Error as e:
        print(f"Ошибка при выполнении SQL: {e}")
        return images, annotations

    for wtf_wtl_id, blob_data, json_data in rows:
        # Parse the annotation first: rows without a usable annotation are
        # skipped before any image decoding is attempted.
        try:
            annotation = json.loads(json_data)
        except (TypeError, ValueError):
            # TypeError: NULL column; ValueError: malformed JSON
            # (json.JSONDecodeError is a ValueError subclass).
            continue

        # A single unreadable blob must not discard every other row.
        try:
            image = Image.open(io.BytesIO(blob_data))
        except (IOError, TypeError) as e:
            print(f"Ошибка при открытии изображения: {e}")
            continue

        images.append((wtf_wtl_id, image))
        annotations.append((wtf_wtl_id, annotation))

    return images, annotations







In [7]:
# Raw string: the Windows path contains backslash sequences (\W, \C, \S, \D,
# \d, \s) that are invalid escapes in a normal string literal and trigger a
# SyntaxWarning on recent Python versions. The resulting value is unchanged.
db_path = r'C:\Work\CLASSES\SPRING2024\DeepMorphoDynamics\deepmorpho\so_deep.db'
conn = connect_to_database(db_path)
if conn:
    try:
        images, annotations = extract_data(conn)
    finally:
        # Release the connection even if extraction raises.
        conn.close()
else:
    print("Не удалось подключиться к базе данных.")

Успешно подключились к базе данных


In [8]:
# Build one summary record per annotated frame, in the same order as
# `annotations` (and therefore as `images`).
result = []

# `frame_id` instead of `id`: at notebook scope the old name shadowed the
# builtin `id()` for every later cell.
for frame_id, frame_annotation in annotations:
    # Keep only the most confident class prediction for this frame.
    top_prediction = max(frame_annotation['predictions'], key=lambda p: p['prediction'])
    result.append({
        'ID': frame_id,
        'label': top_prediction['label'],
        'prediction': top_prediction['prediction'],
        # NOTE(review): always takes the first bbox, not necessarily the one
        # belonging to `top_prediction` - confirm how bboxes pair with predictions.
        'bbox': frame_annotation['bboxes'][0],
    })

print(result[0])
images[0]


{'ID': 461, 'label': 'BLFL', 'prediction': 0.6353023648262024, 'bbox': [210, 118, 425, 323]}


(461, <PIL.BmpImagePlugin.BmpImageFile image mode=L size=500x500>)

In [9]:
# Every distinct label that occurs in any prediction of any frame
# (not only the top-scoring one).
all_classes = {
    candidate['label']
    for entry in annotations
    for candidate in entry[1]['predictions']
}

len(all_classes)

15

Next, we prepare the dataset.

In [10]:
# Preprocessing applied to each image: resize to 224x224, convert to a tensor,
# then normalize with the ImageNet default mean/std constants from timm.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
]
transform = transforms.Compose(preprocessing_steps)

In [12]:
dataset = EmbryoDataset(result, images, transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Hold out 20% of the sample indices for testing; the seed is fixed so the
# split is repeatable. Only the two index lists are kept - the label halves
# returned by train_test_split are discarded.
sample_indices = range(len(dataset))
split_parts = train_test_split(
    sample_indices,
    dataset.labels,
    test_size=0.2,
    random_state=42,
)
train_indices, test_indices = split_parts[0], split_parts[1]

Next, we set up the Vision Transformer (ViT) model.

In [18]:
# Index-based views over the dataset for the two halves of the split.
train_subset, test_subset = (
    Subset(dataset, subset_indices)
    for subset_indices in (train_indices, test_indices)
)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_dataloader = DataLoader(train_subset, shuffle=True, batch_size=32)
test_dataloader = DataLoader(test_subset, shuffle=False, batch_size=32)

# timm model 'vit_base_patch16_224' created with pretrained=True and an output
# size equal to the number of distinct labels found in the annotations.
n_classes = len(all_classes)
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=n_classes)
model.train()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)


In [20]:
def evaluate_accuracy(dataloader, model):
    """Return the top-1 accuracy of ``model`` over ``dataloader`` as a percentage.

    The model is put in eval mode and gradients are disabled while iterating.
    Afterwards the model is always switched back to train mode (also when
    evaluation raises), matching how the training loop uses this function.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        model: Callable ``torch.nn.Module`` mapping an image batch to
            per-class scores along dimension 1.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the dataloader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    correct = 0
    total = 0
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                outputs = model(images)
                # Index of the highest score per sample is the predicted class.
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train()

    if total == 0:
        return 0.0
    return 100 * correct / total

In [None]:
num_epochs = 3

model_dir = 'saved_models'
os.makedirs(model_dir, exist_ok=True)
for epoch in range(num_epochs):
    print(f'Epoch n: {epoch+1}/{num_epochs}')
    print('-' * 10)

    model.train()
    epoch_loss = 0.0
    num_batches = 0
    # `batch_images` / `batch_labels`: using `images` as the loop variable at
    # notebook scope would overwrite the global list of loaded PIL images.
    for batch_idx, (batch_images, batch_labels) in enumerate(train_dataloader):
        print(f'Working on batch  {batch_idx+1}/{len(train_dataloader)}')

        optimizer.zero_grad()
        outputs = model(batch_images)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # One checkpoint per epoch.
    torch.save(model.state_dict(), os.path.join(model_dir, f'model_epoch_{epoch+1}.pth'))
    accuracy = evaluate_accuracy(test_dataloader, model)
    # Report the mean training loss over the epoch rather than the last
    # batch's loss; NaN if the loader produced no batches.
    mean_loss = epoch_loss / num_batches if num_batches else float('nan')
    print(f'Epoch [{epoch+1}/{num_epochs}] '
          f'Loss: {mean_loss}, Accuracy: {accuracy}%')


Epoch n: 1/3
----------
Working on batch  1/290
Working on batch  2/290
Working on batch  3/290
Working on batch  4/290
Working on batch  5/290
Working on batch  6/290
Working on batch  7/290
Working on batch  8/290
Working on batch  9/290
Working on batch  10/290
Working on batch  11/290
Working on batch  12/290
Working on batch  13/290
Working on batch  14/290
Working on batch  15/290
Working on batch  16/290
Working on batch  17/290
Working on batch  18/290
Working on batch  19/290
Working on batch  20/290
Working on batch  21/290
Working on batch  22/290
Working on batch  23/290
Working on batch  24/290
Working on batch  25/290
Working on batch  26/290
Working on batch  27/290
Working on batch  28/290
Working on batch  29/290
Working on batch  30/290
Working on batch  31/290
Working on batch  32/290
Working on batch  33/290
Working on batch  34/290
Working on batch  35/290
Working on batch  36/290
Working on batch  37/290
Working on batch  38/290
Working on batch  39/290
Working on

In [15]:
def plot_roc_curve(dataloader, positive_class=1):
    """Plot a one-vs-rest ROC curve for ``positive_class`` over ``dataloader``.

    Uses the notebook-level ``model``. Softmax probabilities of
    ``positive_class`` are the scores; a sample counts as positive when its
    label equals ``positive_class`` and as negative otherwise. For labels that
    are already binary {0, 1} and the default ``positive_class=1`` this is the
    same curve as scoring the raw labels directly.

    Note: the model is switched to eval mode and is left in eval mode.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        positive_class: Class index treated as the positive class.
    """
    model.eval()
    test_probs = []
    test_targets = []
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            test_probs.extend(probabilities[:, positive_class].cpu().numpy())
            test_targets.extend(labels.cpu().numpy())

    # roc_curve is a binary metric: passing raw multiclass labels raises
    # ValueError, so reduce the targets to "is positive_class" first.
    binary_targets = (np.asarray(test_targets) == positive_class).astype(int)
    fpr, tpr, _ = roc_curve(binary_targets, test_probs)
    roc_auc = auc(fpr, tpr)

    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    # Diagonal reference line: performance of a random classifier.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()