In [1]:
import lzma
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

In [2]:
def min_max_normalize_row(segment):
    """Linearly rescale ``segment`` so its minimum becomes 0 and its maximum 1.

    If the maximum is not strictly greater than the minimum (e.g. a constant
    segment), the input object is returned unchanged to avoid dividing by zero.
    """
    lo, hi = segment.min(), segment.max()
    if not hi > lo:
        return segment
    span = hi - lo
    return (segment - lo) / span

In [3]:
def find_segments(labels, segment_length, min_region_length):
    """Group positive window labels into merged changepoint regions.

    ``labels[i]`` is the 0/1 label of the window starting at index ``i``.
    A candidate region ``(i, j)`` requires:

    * ``labels[i] == 1`` and ``labels[j] == 1``;
    * ``j >= i + min_region_length + 1``;
    * every ``k`` with ``i < k < j`` and ``labels[k] == 0`` has
      ``labels[k - segment_length] == 1``. When ``k < segment_length`` there is
      no such earlier window, so that ``k`` breaks the region.

    Candidates are merged in order of start index: a candidate whose start lies
    within ``segment_length`` of the previous merged region's end is folded
    into it (end = max of the two ends).

    Because merging keeps the maximum end, only the furthest valid end for each
    start affects the result, so this runs in O(n) instead of enumerating every
    (i, j, k) triple.

    Args:
        labels: Sequence of labels of length n (list, NumPy array, or tensor;
            entries may be single-element tensors, e.g. rows of an ``(n, 1)``
            tensor).
        segment_length: Classifier window length; used both as the look-back
            distance and as the merge tolerance.
        min_region_length: Minimum index gap between a region's start and end
            (assumed non-negative).

    Returns:
        List of ``(start, end)`` tuples of window indices, sorted by start.
    """
    n = len(labels)
    # Materialise once as Python bools; bool() also handles 1-element tensors.
    is_one = [bool(v == 1) for v in labels]
    is_zero = [bool(v == 0) for v in labels]

    # breaks[k]: a 0-labelled window at k invalidates any region spanning it.
    # Guard k < segment_length explicitly -- a negative index would otherwise
    # wrap around and read a label from the end of the sequence.
    breaks = [
        is_zero[k] and (k < segment_length or not is_one[k - segment_length])
        for k in range(n)
    ]

    # next_break[k]: smallest index > k with breaks set, or n if there is none.
    next_break = [n] * n
    upcoming = n
    for k in range(n - 1, -1, -1):
        next_break[k] = upcoming
        if breaks[k]:
            upcoming = k

    # last_one[x]: largest index <= x labelled 1, or -1 if there is none.
    last_one = [-1] * n
    latest = -1
    for x in range(n):
        if is_one[x]:
            latest = x
        last_one[x] = latest

    merged_segments = []
    for start in range(n):
        if not is_one[start]:
            continue
        # Valid ends lie strictly before the first break after `start`
        # (the break itself is labelled 0, so it cannot be an end).
        end = last_one[next_break[start] - 1]
        if end < start + min_region_length + 1:
            continue
        if not merged_segments or merged_segments[-1][1] + segment_length < start:
            merged_segments.append((start, end))
        else:  # overlapping or within segment_length of the previous region
            merged_segments[-1] = (merged_segments[-1][0], max(merged_segments[-1][1], end))

    return merged_segments

In [4]:
# Load the TorchScript classifier onto the CPU (the window tensors built later
# are CPU tensors, so a GPU-saved model would otherwise mismatch or fail to load)
# and switch it to inference mode.
model = torch.jit.load("trained_model.pth", map_location="cpu")
model.eval()

# The classifier scores fixed-length windows; the input width of its first
# linear layer defines that window length.
segment_length = model.fc1.in_features

# Minimum index gap between a region's start and end window (see find_segments).
min_region_length = 2

In [5]:
# Signal profiles ship as an xz-compressed CSV; decompress while reading.
with lzma.open("../../data/cancer/profiles.csv.xz", 'rt') as file:
    signal_df = pd.read_csv(file)

# Annotated label regions (columns used later: sequenceID, start, end, changes).
labels_df = pd.read_csv("../../data/cancer/labels.csv")

# Sorted unique sequence IDs, keeping every second one (1st, 3rd, ...).
# NOTE(review): presumably one member of each paired ID (e.g. .F1/.F2) — confirm.
unique_seq_ids = sorted(signal_df['sequenceID'].unique())
list_seqID = unique_seq_ids[::2]

In [None]:
# For each sequence: predict changepoint regions with the classifier and save a
# two-panel figure (top: predicted regions, bottom: annotated label regions).
os.makedirs('figures_cancer', exist_ok=True)  # savefig raises if the directory is missing

for i, seqID in enumerate(list_seqID):
    seq_df = signal_df[signal_df['sequenceID'] == seqID]
    seq = seq_df['signal'].to_numpy()
    # Skip sequences too short to yield a single classifier window.
    if len(seq) < segment_length:
        continue

    # All stride-1 windows of length segment_length, each min-max normalized.
    # (`start` avoids reusing the outer loop's `i`, which the figure title uses.)
    segments = np.array([seq[start:start + segment_length] for start in range(len(seq) - segment_length + 1)])
    segments_normalized = np.array([min_max_normalize_row(segment) for segment in segments])
    segments_tensor = torch.tensor(segments_normalized, dtype=torch.float32)

    # Binary window labels from thresholding the model output at 0.5.
    # no_grad: inference only, so don't build an autograd graph.
    with torch.no_grad():
        labels = (model(segments_tensor) > 0.5).int()

    # Merge positive windows into (start, end) window-index regions.
    regions = find_segments(labels, segment_length, min_region_length)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 4), sharex=True)

    # Top panel: raw sequence with predicted regions. The last window of a region
    # starts at rect_end, so the shading extends segment_length beyond it.
    ax1.plot(seq, label='sequence', color='black')
    for rect_start, rect_end in regions:
        # axvspan's ymin/ymax are axes fractions in [0, 1], not data values;
        # the defaults shade the full panel height.
        ax1.axvspan(rect_start, rect_end + segment_length, color='red', alpha=0.3)
    ax1.set_xlabel('Index')
    ax1.set_ylabel('Value')
    ax1.set_title(f"Sequence: {i}-{seqID}")
    ax1.legend(handles=[
        Line2D([0], [0], color='black', lw=2, label='Sequence'),
        Patch(color='red', alpha=0.3, label='Changepoint Region')
    ], loc='upper left', bbox_to_anchor=(1, 1), handlelength=2)

    # Bottom panel: signal against (position - 1) with annotated label regions.
    # Labels for this ID and for the same ID with its 3-char suffix replaced by
    # ".F2" are both drawn.
    sequence_id = seqID[:-3]
    seq_label_df = labels_df[(labels_df['sequenceID'] == seqID) | (labels_df['sequenceID'] == (sequence_id + ".F2"))]

    ax2.plot(seq_df['position'] - 1, seq_df['signal'], color='blue', label=f"{sequence_id}")
    signal_min = seq_df['signal'].min()
    signal_max = seq_df['signal'].max()

    for _, row in seq_label_df.iterrows():
        # Label regions with changes == 0 are pink; all others are red.
        color = 'pink' if row['changes'] == 0 else 'red'
        # NOTE(review): the -0.7 / -1.3 offsets presumably align the label bounds
        # with the shifted (position - 1) x axis plus padding; confirm.
        start = row['start'] - 0.7
        end = row['end'] - 1.3
        ax2.add_patch(plt.Rectangle((start, signal_min), end - start, signal_max - signal_min, color=color, alpha=0.3))

    ax2.set_xlabel('Position')
    ax2.set_ylabel('Signal')
    ax2.legend()

    fig.tight_layout()
    fig.savefig(f'figures_cancer/sequence_{sequence_id}.png', bbox_inches='tight')
    plt.close(fig)  # free the figure; one is created per sequence in this loop