In [None]:
import re
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi

from src.visualise import settings
from src.visualise.plot import plot_data
from src.data.paths import project_dir
from src.data.analysis import read_tiff_img, Circle, create_circular_mask
from src.data.detector import find_circle_hough_method, img_for_circle_detection

# Data structures

In [None]:
@dataclass(frozen=True)
class DetectorImage:
    """A single detector image together with the file it was read from."""
    image: np.ndarray
    path: Path

    @property
    def init_circle(self) -> Circle:
        """Default circle centred on the image with a fixed radius of 100 px.

        Images in this notebook are indexed ``image[y, x]`` (rows first), so the
        x centre comes from the number of columns (``shape[1]``) and the y centre
        from the number of rows (``shape[0]``).
        """
        n_rows, n_cols = self.image.shape[:2]
        return Circle(x=n_cols // 2, y=n_rows // 2, r=100)

@dataclass(frozen=True)
class DetectorData:
    """Raw and live-view images of one detector plus its fitted circle."""
    raw: DetectorImage
    lv: DetectorImage
    det_no: int
    # default_factory builds a fresh Circle() per instance instead of sharing one
    # default object between every DetectorData created without a circle.
    circle: Circle = field(default_factory=Circle)

@dataclass(frozen=True)
class DetectorDataCollection:
    """Per-detector image pairs keyed by detector number.

    When constructed without ``data`` the collection is populated from ``path``.
    Every entry whose name ends with ``'lv'`` is a live-view folder. The first
    run of digits in its name is the detector id, and the matching raw-data
    folder is ``path / <id>``. When ``data`` is passed (even an empty dict), it
    is used as-is and nothing is read from disk.
    """
    path: Path
    # None means "load from disk"; an explicitly passed (possibly empty) dict is kept.
    data: Optional[dict[int, DetectorData]] = None

    def __post_init__(self):
        if self.data is None:
            # frozen dataclass: bypass the generated __setattr__ to initialise the field once
            object.__setattr__(self, 'data', {})
            self._load_data()

    @staticmethod
    def _first_tif(folder: Path) -> Optional[Path]:
        """Return the lexicographically first ``*tif`` below ``folder`` (deterministic), or None."""
        return min(folder.glob('**/*tif'), default=None)

    def _load_data(self):
        """Fill ``self.data`` from ``self.path``, printing each detector number (or ``missing_<no>``)."""
        for file_path in sorted(self.path.iterdir()):
            if not file_path.name.endswith('lv'):
                continue
            digits = re.findall(r'\d+', file_path.name)
            if not digits:
                # no detector number in the name, so there is no raw folder to pair it with
                print(f"skipped_{file_path.name} ", end='')
                continue
            det_id = digits[0]
            det_no = int(det_id)
            lv_path = self._first_tif(file_path)
            raw_path = self._first_tif(self.path / det_id)
            if lv_path is None or raw_path is None:
                print(f"missing_{det_no} ", end='')
                continue
            lv_image = DetectorImage(image=read_tiff_img(lv_path, border_px=0), path=lv_path)
            raw_image = DetectorImage(image=read_tiff_img(raw_path, border_px=0), path=raw_path)
            self.data[det_no] = DetectorData(raw=raw_image, lv=lv_image, det_no=det_no)
            print(f"{det_no} ", end='')

In [None]:
raw_path = project_dir.joinpath('data', 'raw', '2024-02-20')
# Scan the proton measurement folder; prints every detector number as it is loaded.
proton_raw_data = DetectorDataCollection(path=raw_path)

In [None]:
co60_path = project_dir.joinpath('data', 'raw', 'Co60')
# Scan the Co60 measurement folder; prints every detector number as it is loaded.
co60_raw_data = DetectorDataCollection(path=co60_path)

# Proton raw data

In [None]:
# Live-view image of proton detector 1 with the default image-centred circle.
fig, ax = plot_data(
    proton_raw_data.data[1].lv.image,
    circle_px=proton_raw_data.data[1].lv.init_circle,
)

In [None]:
# Raw image of proton detector 1 with the default image-centred circle.
fig, ax = plot_data(
    proton_raw_data.data[1].raw.image,
    circle_px=proton_raw_data.data[1].raw.init_circle,
)

# Co60 raw data

In [None]:
# Live-view image of Co60 detector 1 with the default image-centred circle.
fig, ax = plot_data(
    co60_raw_data.data[1].lv.image,
    circle_px=co60_raw_data.data[1].lv.init_circle,
)

In [None]:
# Raw image of Co60 detector 1 with the default image-centred circle.
fig, ax = plot_data(
    co60_raw_data.data[1].raw.image,
    circle_px=co60_raw_data.data[1].raw.init_circle,
)

## Background

In [None]:
# Background image: choose deterministically (lexicographically first match) instead of
# whatever order the filesystem happens to return, and fail with a clear error when
# nothing matches instead of a bare StopIteration.
background_candidates = sorted(raw_path.parent.glob('**/*background*/**/**/*tif'))
if not background_candidates:
    raise FileNotFoundError(f"no background *tif found under {raw_path.parent}")
background_path = background_candidates[0]
background_data = DetectorImage(path=background_path, image=read_tiff_img(background_path, border_px=0))

In [None]:
# Background image with the default image-centred circle.
fig, ax = plot_data(
    background_data.image,
    circle_px=background_data.init_circle,
)

# Background subtraction

In [None]:
proton_bg_sub_data = DetectorDataCollection(path=proton_raw_data.path, data=deepcopy(proton_raw_data.data))
for data in proton_bg_sub_data.data.values():
    # In-place background subtraction on the deep-copied raw images, with negative results set to 0.
    # The images are stored as unsigned integers (uint16), so `raw - background` would wrap around
    # below zero. The former "shift up by background.max()" trick wrapped around above 65535 for
    # bright pixels instead. raw - min(raw, background) == max(raw - background, 0) never leaves range.
    np.subtract(data.raw.image, np.minimum(data.raw.image, background_data.image), out=data.raw.image)


In [None]:
co60_bg_sub_data = DetectorDataCollection(path=co60_raw_data.path, data=deepcopy(co60_raw_data.data))
for data in co60_bg_sub_data.data.values():
    # In-place background subtraction on the deep-copied raw images, with negative results set to 0.
    # The images are stored as unsigned integers (uint16), so `raw - background` would wrap around
    # below zero. The former "shift up by background.max()" trick wrapped around above 65535 for
    # bright pixels instead. raw - min(raw, background) == max(raw - background, 0) never leaves range.
    np.subtract(data.raw.image, np.minimum(data.raw.image, background_data.image), out=data.raw.image)

In [None]:
# Proton detector 1 before background subtraction (for comparison with the next cell).
fig, ax = plot_data(
    proton_raw_data.data[1].raw.image,
    circle_px=proton_raw_data.data[1].raw.init_circle,
)

In [None]:
# Proton detector 1 after background subtraction.
fig, ax = plot_data(
    proton_bg_sub_data.data[1].raw.image,
    circle_px=proton_bg_sub_data.data[1].raw.init_circle,
)

# Detector discovery

In [None]:
# Fit the detector circle on each live-view image and attach it to the background-subtracted data.
det_data_dict = {}
for data in proton_bg_sub_data.data.values():
    circle = find_circle_hough_method(img_for_circle_detection(data.lv.image))
    det_data_dict[data.det_no] = DetectorData(raw=data.raw, lv=data.lv, det_no=data.det_no, circle=circle)
    print(f"{data.det_no} ", end='')
proton_det_data = DetectorDataCollection(path=proton_raw_data.path, data=det_data_dict)

In [None]:
# Fit the detector circle on each live-view image and attach it to the background-subtracted data.
det_data_dict = {}
for data in co60_bg_sub_data.data.values():
    circle = find_circle_hough_method(img_for_circle_detection(data.lv.image))
    det_data_dict[data.det_no] = DetectorData(raw=data.raw, lv=data.lv, det_no=data.det_no, circle=circle)
    print(f"{data.det_no} ", end='')
co60_det_data = DetectorDataCollection(path=co60_raw_data.path, data=det_data_dict)

In [None]:
# Proton detector 1 with the circle found by the Hough transform.
fig, ax = plot_data(
    proton_det_data.data[1].raw.image,
    circle_px=proton_det_data.data[1].circle,
)

In [None]:
# Nominal detector radius in mm. NOTE(review): assumed physical size, confirm against the spec.
mean_radius_mm = 10.0
# Average fitted radius (px) over the proton detectors defines the px -> mm scale.
mean_radius = np.mean([d.circle.r for d in proton_det_data.data.values()])
px_to_mm = mean_radius_mm / mean_radius
px_to_mm

In [None]:
fig, ax = plt.subplots(2, 3, figsize=(12, 8))
# Rows: proton (0) and Co60 (1) detectors; columns: fitted centre x, centre y and radius, in px.
row_specs = ((proton_det_data, 'Proton det no'), (co60_det_data, 'Co60 det no'))
col_ylabels = ('Det center x [px]', 'Det center y [px]', 'Det radius [px]')
for row, (collection, xlabel) in enumerate(row_specs):
    for data in collection.data.values():
        for col, value in enumerate((data.circle.x, data.circle.y, data.circle.r)):
            ax[row, col].plot(data.det_no, value, '.')
    for a in ax[row]:
        a.set_xlabel(xlabel)
for col, ylabel in enumerate(col_ylabels):
    for a in ax[:, col]:
        a.set_ylabel(ylabel)
for a in ax.flat:
    a.grid()
fig.tight_layout()

In [None]:
fig, ax = plt.subplots(2, 3, figsize=(12, 8))
# Same layout as the px figure, with every value scaled by px_to_mm.
row_specs = ((proton_det_data, 'Proton det no'), (co60_det_data, 'Co60 det no'))
col_ylabels = ('Det center x [mm]', 'Det center y [mm]', 'Det radius [mm]')
for row, (collection, xlabel) in enumerate(row_specs):
    for data in collection.data.values():
        for col, value in enumerate((data.circle.x, data.circle.y, data.circle.r)):
            ax[row, col].plot(data.det_no, value * px_to_mm, '.')
    for a in ax[row]:
        a.set_xlabel(xlabel)
for col, ylabel in enumerate(col_ylabels):
    for a in ax[:, col]:
        a.set_ylabel(ylabel)
for a in ax.flat:
    a.grid()
fig.tight_layout()

## Detector cutting

In [None]:
# Smallest fitted radius in each dataset, and the smallest over both rounded down to whole pixels.
min_radius_protons = min(d.circle.r for d in proton_det_data.data.values())
min_radius_co60 = min(d.circle.r for d in co60_det_data.data.values())
smallest_radius = min(min_radius_protons, min_radius_co60)
min_radius_all = np.floor(smallest_radius).astype(int)

print(f"min radius - protons: {min_radius_protons:.2f}, Co60: {min_radius_co60:.2f}, all: {min_radius_all}")
print(f"min radius - protons: {min_radius_protons * px_to_mm:.2f} [mm], Co60: {min_radius_co60 * px_to_mm:.2f} [mm], all: {min_radius_all * px_to_mm:.2f} [mm]")

In [None]:
def cut_detector(det_data: DetectorData, min_radius: int, factor: float = 1.2) -> DetectorData:
    """Crop the raw and live-view images to a square window around the detector circle.

    The window reaches ``factor * min_radius`` px from the circle centre in each
    direction. Images are indexed ``image[y, x]``.

    Args:
        det_data: detector with a fitted ``circle`` in full-image pixel coordinates.
        min_radius: reference radius in px that sets the window size.
        factor: half-width of the window in units of ``min_radius``.

    Returns:
        New ``DetectorData`` holding cropped copies of both images and the circle
        centre shifted into the cropped frame (radius unchanged).
    """
    half_width = factor * min_radius
    # Clamp the lower bounds at 0: a negative start index would wrap around and
    # silently give a wrong (or empty) crop for detectors near the image border.
    # Upper bounds beyond the image are already handled by slicing.
    lower_x = max(int(np.ceil(det_data.circle.x - half_width)), 0)
    upper_x = int(np.floor(det_data.circle.x + half_width))
    lower_y = max(int(np.ceil(det_data.circle.y - half_width)), 0)
    upper_y = int(np.floor(det_data.circle.y + half_width))

    new_circle = Circle(x=det_data.circle.x - lower_x, y=det_data.circle.y - lower_y, r=det_data.circle.r)
    # numpy slicing returns a view (`[:]` does not copy); copy so the crops do not
    # share memory with the full-size images.
    cut_raw_image = det_data.raw.image[lower_y:upper_y, lower_x:upper_x].copy()
    cut_lv_image = det_data.lv.image[lower_y:upper_y, lower_x:upper_x].copy()

    return DetectorData(
        raw=DetectorImage(image=cut_raw_image, path=det_data.raw.path),
        lv=DetectorImage(image=cut_lv_image, path=det_data.lv.path),
        det_no=det_data.det_no,
        circle=new_circle,
    )

In [None]:
# Sanity check of the crop: proton detector 1 live-view image with the shifted circle.
c1 = cut_detector(det_data=proton_det_data.data[1], min_radius=min_radius_all)
fig, ax = plot_data(
    c1.lv.image,
    circle_px=c1.circle,
)

In [None]:
# Crop every proton detector to the common window size derived from min_radius_all.
det_data_dict = {}
for data in proton_det_data.data.values():
    cropped = cut_detector(data, min_radius_all)
    det_data_dict[data.det_no] = cropped
    print(f"{data.det_no} ", end='')
proton_det_cut_data = DetectorDataCollection(path=proton_det_data.path, data=det_data_dict)

In [None]:
# Crop every Co60 detector to the common window size derived from min_radius_all.
det_data_dict = {}
for data in co60_det_data.data.values():
    cropped = cut_detector(data, min_radius_all)
    det_data_dict[data.det_no] = cropped
    print(f"{data.det_no} ", end='')
co60_det_cut_data = DetectorDataCollection(path=co60_det_data.path, data=det_data_dict)

In [None]:
# Cropped raw image of one proton detector, overlaid with that same detector's circle
# (the image and the circle must come from the same entry).
det_no_to_show = 19
fig, ax = plot_data(
    proton_det_cut_data.data[det_no_to_show].raw.image,
    circle_px=proton_det_cut_data.data[det_no_to_show].circle,
)

# Mean signal inspection

In [None]:
# Radius (px) of the central region used for the mean-signal statistics: half the common detector radius.
singal_radius = min_radius_all / 2
singal_radius

In [None]:
from collections import defaultdict
import pandas as pd

# One row per detector: mean / std of the raw signal inside a central circle of radius singal_radius.
df_data = defaultdict(list)
for collection, radiation in ((proton_det_cut_data, "proton"), (co60_det_cut_data, "Co60")):
    for data in collection.data.values():
        central_circle = Circle(x=data.circle.x, y=data.circle.y, r=singal_radius)
        mask = create_circular_mask(img=data.raw.image, circle_px=central_circle)
        signal = data.raw.image[mask]
        df_data["det_no"].append(int(data.det_no))
        df_data["mean_signal"].append(np.mean(signal))
        df_data["std_signal"].append(np.std(signal))
        df_data["radiation"].append(radiation)
df = pd.DataFrame(df_data).sort_values(by=["radiation", "det_no"])
df

In [None]:
# Central mean signal per proton detector, error bars = pixel std inside the circle.
df.loc[df['radiation'] == 'proton'].plot(
    kind='bar', x='det_no', y='mean_signal', yerr='std_signal', figsize=(12, 8), grid=True
)

In [None]:
# Central mean signal per Co60 detector, error bars = pixel std inside the circle.
df.loc[df['radiation'] == 'Co60'].plot(
    kind='bar', x='det_no', y='mean_signal', yerr='std_signal', figsize=(12, 8), grid=True
)

# Good dataset selection

In [None]:
# Keep only the proton detectors with det_no <= 19.
det_data_dict = {}
for data in proton_det_cut_data.data.values():
    if data.det_no > 19:
        continue
    det_data_dict[data.det_no] = data
    print(f"{data.det_no} ", end='')
proton_det_sel_data = DetectorDataCollection(path=proton_det_cut_data.path, data=det_data_dict)

In [None]:
# Keep only the Co60 detectors with det_no <= 19.
det_data_dict = {}
for data in co60_det_cut_data.data.values():
    if data.det_no > 19:
        continue
    det_data_dict[data.det_no] = data
    print(f"{data.det_no} ", end='')
co60_det_sel_data = DetectorDataCollection(path=co60_det_cut_data.path, data=det_data_dict)

In [None]:
import random
import math

# Fixed seed so the sampled circle positions (and every statistic derived from them) are reproducible.
RANDOM_SEED = 42
rng = random.Random(RANDOM_SEED)
num_circles = 1000
small_radius = 10
# Small-circle centres stay within 0.8 * min_radius_all - small_radius of the detector centre.
max_shift = 0.8*min_radius_all - small_radius
shift_x, shift_y = [], []
for _ in range(num_circles):
    angle = rng.uniform(0, 2*np.pi)
    # sqrt of a uniform variate makes the centres uniform over the disk area;
    # a uniformly distributed distance would over-sample the detector centre.
    distance = max_shift * math.sqrt(rng.random())
    shift_x.append(distance * math.cos(angle))
    shift_y.append(distance * math.sin(angle))

In [None]:
# One row per (detector, small circle): signal statistics inside each randomly shifted small circle.
# Uses the "good" detectors selected above (*_det_sel_data); the unfiltered *_det_cut_data
# would include the detectors the selection step is meant to drop.
df_data = defaultdict(list)
for collection, radiation in ((proton_det_sel_data, "proton"), (co60_det_sel_data, "Co60")):
    for data in collection.data.values():
        for circle_no, (dx, dy) in enumerate(zip(shift_x, shift_y)):
            circle = Circle(x=data.circle.x + dx, y=data.circle.y + dy, r=small_radius)
            mask = create_circular_mask(img=data.raw.image, circle_px=circle)
            signal = data.raw.image[mask]
            df_data["det_no"].append(int(data.det_no))
            df_data["mean_signal"].append(np.mean(signal))
            df_data["std_signal"].append(np.std(signal))
            df_data["radiation"].append(radiation)
            df_data["circle_no"].append(circle_no)
df = pd.DataFrame(df_data).sort_values(by=["radiation", "det_no"])
df

In [None]:
import seaborn as sns

sns.set_theme(style="whitegrid")
# One panel per radiation type; each dot is one small circle's mean signal, coloured by circle index.
sns.relplot(data=df, x="det_no", y="mean_signal", hue="circle_no", col="radiation")


In [None]:
# Per (radiation, det_no): mean and sample std (ddof=1) of the small-circle mean signals.
signal_by_detector = df.groupby(["radiation", "det_no"])["mean_signal"]
df2 = signal_by_detector.agg(["mean", "std"])
df2

In [None]:
# Mean of the small-circle means per detector, one line per radiation type.
sns.relplot(
    data=df2.reset_index(),
    x='det_no',
    y='mean',
    hue='radiation',
    style='radiation',
    kind='line',
)

In [None]:
# Spread (std) of the small-circle means per detector, one line per radiation type.
sns.relplot(
    data=df2.reset_index(),
    x='det_no',
    y='std',
    hue='radiation',
    style='radiation',
    kind='line',
)

In [None]:
# Reset the index if df2 is a MultiIndex DataFrame
df2_reset = df2.reset_index()

# Create a figure and axis object
plt.figure(figsize=(10, 6))

# Plot the mean values with line and scatter points
sns.lineplot(data=df2_reset, x='det_no', y='mean', hue='radiation', style='radiation', markers=True, dashes=False)

# Add error bars
for radiation_type in df2_reset['radiation'].unique():
    subset = df2_reset[df2_reset['radiation'] == radiation_type]
    plt.errorbar(subset['det_no'], subset['mean'], yerr=subset['std'], fmt='none', capsize=5, label=f'{radiation_type} Error')

plt.title('Mean Values and Standard Deviations by Detector Number')
plt.xlabel('Detector Number')
plt.ylabel('Measured Value')
plt.legend(title='Radiation Type')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# df2 (from the groupby cell) holds 'mean' and 'std' per (radiation, det_no); flatten its index.
df2_reset = df2.reset_index()
# Co60 rows only.
df2_co60 = df2_reset.loc[df2_reset['radiation'] == 'Co60']

plt.figure(figsize=(10, 6))

# Per-detector mean for Co60 as a line with markers.
sns.lineplot(data=df2_co60, x='det_no', y='mean', hue='radiation', style='radiation', markers=True, dashes=False)

# Std as error bars with point markers.
plt.errorbar(df2_co60['det_no'], df2_co60['mean'], yerr=df2_co60['std'], fmt='o', capsize=5, label='Co60 Error')

plt.title('Mean Values and Standard Deviations for Co60 Radiation')
plt.xlabel('Detector Number')
plt.ylabel('Measured Value')
plt.legend(title='Radiation Type')

In [None]:
# FIXME(review): `proton_data` and `c` are not defined in any visible cell of this notebook
# (leftovers from an earlier version?); define them before running this cell.
# background_data is a DetectorImage, so its pixel array is background_data.image.
plot_data(
    (proton_data.astype(np.int64) - background_data.image.astype(np.int64)).clip(0, None),
    path='',
    circle_px=Circle(c.x, c.y, 80),
)

In [None]:
dose_proton_Gy = 5

# FIXME(review): `proton_data` and `c` are not defined in any visible cell of this notebook.
# background_data is a DetectorImage, so its pixel array is background_data.image.
proton_data_bg_removed = (proton_data.astype(np.int64) - background_data.image.astype(np.int64)).clip(0, None)
mask_for_circle = create_circular_mask(img=proton_data_bg_removed, circle_px=Circle(c.x, c.y, 80))
signal_in_circle = proton_data_bg_removed[mask_for_circle]
# Statistics only over pixels that still carry signal after background removal (> 0).
proton_data_bg_removed_mean = np.mean(signal_in_circle, where=signal_in_circle > 0)
proton_data_bg_removed_std = np.std(signal_in_circle, where=signal_in_circle > 0)

# Mean signal per Gy, and relative spread (std / mean)
proton_data_bg_removed_mean / dose_proton_Gy, proton_data_bg_removed_std / proton_data_bg_removed_mean

# Efficiency

In [None]:
# FIXME(review): `co60_data`, `dose_Co60`, `proton_data` and `c` are not defined in any visible cell.
# background_data is a DetectorImage, so its pixel array is background_data.image.
# Per-pixel signal per Gy after background removal; clipped at 1 so the ratio below never divides by 0.
co60_signal_per_Gy = (co60_data.astype(np.int64) - background_data.image.astype(np.int64)).clip(1, None) / dose_Co60
proton_signal_per_Gy = (proton_data.astype(np.int64) - background_data.image.astype(np.int64)).clip(1, None) / dose_proton_Gy
# Relative response map (proton / Co60 per Gy), clipped to [0.001, 2] for display.
plot_data((proton_signal_per_Gy / co60_signal_per_Gy).clip(0.001, 2), path='', circle_px=Circle(c.x, c.y, 80))

In [None]:
# Ratio of mean signal per Gy, proton over Co60.
# FIXME(review): `co60_data_bg_removed_mean` and `dose_Co60` are not defined in any visible cell.
(proton_data_bg_removed_mean / dose_proton_Gy) / (co60_data_bg_removed_mean / dose_Co60)

In [None]:
# Proton/Co60 signal-per-Gy ratio map, clipped to [0.001, 2] and median-filtered with a 20 px window.
# FIXME(review): `c` is not defined in any visible cell; the *_signal_per_Gy arrays come from the efficiency-map cell.
plot_data(ndi.median_filter((proton_signal_per_Gy / co60_signal_per_Gy).clip(0.001,2),size=20), path='', circle_px=Circle(c.x, c.y, 80))