In [5]:
import numpy as np
import matplotlib.pyplot as plt
import panel as pn
# Initialise panel's notebook extension so panel-based output can render inline
# (presumably needed by FlipClassifierWidget.show() below — TODO confirm).
pn.extension()

from ratmoseq_extract.flip import FlipClassifierWidget, create_training_dataset, train_classifier, save_classifier

In [6]:
# Cohort directory ('4weeks') whose sessions are used to build the flip training set.
# NOTE(review): later cells also hard-code '4weeks' in output file names — keep them in sync.
data_dir = '/n/groups/datta/jlove/data/rat_seq/rat_seq_paper/data/4weeks'
# Frame source name handed to FlipClassifierWidget (presumably the dataset/file name of the
# extracted frames inside each session — TODO confirm).
frames_name = 'frames'

In [7]:
# Widget over the sessions in `data_dir`; presumably used to mark flipped / non-flipped
# frame ranges by hand (see save_frame_ranges below).
widget = FlipClassifierWidget(data_dir, frames_name)

In [8]:
# Display the widget. NOTE(review): the training labels come from manual interaction here,
# so they are not reproducible from code alone — the saved frame-range file is the artifact.
widget.show()

In [21]:
# Persist the labelled frame ranges; the `!ls` output below lists a
# `4weeks-flip-training-frame-ranges.p`, presumably written by this call.
widget.save_frame_ranges()

In [13]:
# training_set = create_training_dataset(path)

In [34]:
# !mv /n/groups/datta/jlove/data/rat_seq/rat_seq_paper/data/training_data.npz /n/groups/datta/jlove/data/rat_seq/rat_seq_paper/data/14weeks_training_data.npz

In [35]:
# List the data directory to see which cohort folders, frame-range files and saved classifiers exist.
!ls /n/groups/datta/jlove/data/rat_seq/rat_seq_paper/data

14weeks				      adult_control
14weeks-flip.p			      mice_control
14weeks-flip-training-frame-ranges.p  mice_control_v2
14weeks_training_data.npz	      nor
4weeks				      _pca
4weeks-flip-training-frame-ranges.p   tmp
9weeks				      untar.py


In [36]:
# Frame-range file for this cohort, derived from `data_dir` rather than a second hard-coded
# absolute path, so switching cohorts needs only one edit. For the 4weeks cohort this yields
# '/n/groups/datta/jlove/data/rat_seq/rat_seq_paper/data/4weeks-flip-training-frame-ranges.p'.
from pathlib import Path

path = str(Path(data_dir).parent / f"{Path(data_dir).name}-flip-training-frame-ranges.p")

In [37]:
# Build the training set from the saved frame ranges. Per the printed output this
# materialises ~38 GB of frames. `training_set` is passed to np.load below, so it is
# presumably the path of the written .npz — TODO confirm.
training_set = create_training_dataset(path)

Training data shape: (18114, 256, 256); memory usage: 37.99 GB


In [38]:
# create_training_dataset??

In [39]:
# ls /n/groups/datta/jlove/data/rat_seq/rat_seq_paper/data/

In [40]:
# np.load on an .npz returns a lazy NpzFile that holds the archive open; indexing it reads
# each array fully into memory, after which the context manager closes the file handle.
with np.load(training_set) as data:
    frames = data["frames"]
    flipped = data["flipped"]

In [41]:
# Every frame must have exactly one flip label; fail loudly instead of relying on eyeballing.
assert flipped.shape[0] == frames.shape[0], (flipped.shape, frames.shape)
flipped.shape, frames.shape

((18114,), (18114, 256, 256))

In [42]:
from sklearn.decomposition import PCA
def flatten(array: np.ndarray) -> np.ndarray:
    """Collapse every axis after the first, giving a 2-D ``(n_samples, n_features)`` array.

    A 1-D input becomes ``(n, 1)``. The feature count is computed explicitly
    instead of using ``-1`` because ``reshape(0, -1)`` is ambiguous and raises
    for an empty stack; here an empty ``(0, H, W)`` input becomes ``(0, H * W)``.

    NOTE(review): a later cell re-imports ``flatten`` from ratmoseq_extract.flip,
    which shadows this definition.
    """
    n_features = int(np.prod(array.shape[1:]))
    return array.reshape(len(array), n_features)

In [43]:
# pca = PCA(n_components=20)

In [44]:
# pca.fit(flatten(frames[-len(frames) // 3 :]))

In [45]:
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score, KFold
from sklearn.preprocessing import StandardScaler, FunctionTransformer

In [46]:
from ratmoseq_extract.flip import batch_apply_pca, flatten

In [47]:
# def batch_apply_pca(frames: np.ndarray, pca, batch_size: int = 1000) -> np.ndarray:
#     output = []
#     if len(frames) < batch_size:
#         return pca.transform(flatten(frames)).astype(np.float32)

#     for arr in np.array_split(frames, len(frames) // batch_size):
#         output.append(pca.transform(flatten(arr)).astype(np.float32))
#     return np.concatenate(output, axis=0).astype(np.float32)

In [48]:
# classifier = 'RF'
# batch_size=1000

# pipeline = make_pipeline(
#     FunctionTransformer(batch_apply_pca, kw_args={'pca':pca}, validate=False),
#     StandardScaler(),
#     RandomForestClassifier(n_estimators=150)
# )

# pipeline = make_pipeline(
#     FunctionTransformer(flatten, validate=False),
#     PCA(n_components=20), 
#     StandardScaler(),
#     RandomForestClassifier(n_estimators=150)
# )
from sklearn.base import BaseEstimator, TransformerMixin, clone


# Custom transformer: flattens stacked frames and projects them onto PCA components,
# transforming in batches to bound the size of each flattened intermediate.
class BatchPCA(BaseEstimator, TransformerMixin):
    """PCA pipeline step that accepts unflattened sample stacks.

    Parameters
    ----------
    pca : estimator
        Unfitted PCA-like estimator exposing ``fit`` and ``transform``. It is
        cloned in :meth:`fit`, so the object passed here is never mutated; the
        fitted copy is stored as ``pca_``.
    batch_size : int, default=1000
        Number of samples flattened and projected per ``transform`` call.
    """

    def __init__(self, pca, batch_size=1000):
        self.pca = pca
        self.batch_size = batch_size

    def _fitted_pca(self):
        # Objects pickled by earlier versions of this class kept the fitted PCA in
        # ``self.pca`` and have no ``pca_``; fall back so they can still transform.
        return getattr(self, "pca_", self.pca)

    def fit(self, X, y=None):
        """Fit a clone of ``pca`` on the flattened leading two thirds of ``X``.

        Uses the first ``(2 * n) // 3`` samples, i.e. all but the last
        ``ceil(n / 3)``. When that is zero (``n < 2``) all samples are used
        instead of fitting on an empty array.

        Parameters
        ----------
        X : ndarray of shape (n_samples, ...)
            Samples along the first axis; all trailing axes are flattened.
        y : ignored
            Accepted for pipeline API compatibility.

        Returns
        -------
        self
        """
        n_samples = len(X)
        # Fitting on a subset presumably bounds PCA memory/time on large stacks —
        # TODO confirm the intended split. (2n)//3 == n - ceil(n/3), the same rows
        # as the former `X[:-n // 3]`, whose bound parses as `(-n) // 3`.
        n_fit = (2 * n_samples) // 3 or n_samples
        self.pca_ = clone(self.pca).fit(flatten(X[:n_fit]))
        return self

    def transform(self, X):
        """Project ``X`` onto the fitted components, ``batch_size`` samples at a time.

        Parameters
        ----------
        X : ndarray of shape (n_samples, ...)
            Samples along the first axis; all trailing axes are flattened.

        Returns
        -------
        ndarray of shape (n_samples, n_components)
            An empty ``(0, n_components)`` array when ``X`` has no samples.
        """
        pca = self._fitted_pca()
        n_samples = len(X)
        if n_samples == 0:
            # np.concatenate([]) would raise; return a correctly shaped empty result.
            return np.empty((0, pca.n_components_))

        # Process in batches so only one flattened batch is materialised at a time.
        return np.concatenate(
            [
                pca.transform(flatten(X[start:start + self.batch_size]))
                for start in range(0, n_samples, self.batch_size)
            ],
            axis=0,
        )

# Define PCA and other pipeline components.
# A fixed seed makes the fit reproducible across Restart & Run All: PCA's random_state
# applies when it picks a randomized/arpack solver, and the random forest's bootstrap
# and feature sampling are random.
SEED = 0
pca = PCA(n_components=20, random_state=SEED)
batch_size = 100  # samples projected per PCA.transform call (bounds peak memory)

pipeline = make_pipeline(
    BatchPCA(pca, batch_size=batch_size),  # flatten + PCA, transformed in batches
    StandardScaler(),
    RandomForestClassifier(n_estimators=150, random_state=SEED),
)

In [None]:
# accuracy = cross_val_score(
#     pipeline, frames, flipped, cv=KFold(n_splits=4, shuffle=True, random_state=0), error_score='raise'
# )

In [None]:
# accuracy

In [49]:
# Fit the full pipeline (batched PCA -> StandardScaler -> RandomForest) on all labelled
# frames; nothing is held out here (the KFold cross-validation cell is commented out).
pipeline.fit(frames, flipped)

In [50]:
# Save the fitted pipeline next to the cohort folders (output confirms .../data/4weeks-flip.p).
# NOTE(review): a later cell saves a train_classifier model to the same file name, overwriting this one.
save_classifier(pipeline, widget.data_path.parent / "4weeks-flip.p")

Classifier saved to /n/groups/datta/jlove/data/rat_seq/rat_seq_paper/data/4weeks-flip.p


In [51]:
import joblib

In [55]:
# Reload the saved classifier to check that the file round-trips.
# NOTE(review): joblib.load unpickles, so only load trusted files. The pickle presumably
# refers to the BatchPCA class defined in this notebook, so that cell must run first in a fresh kernel.
clf = joblib.load(widget.data_path.parent / "4weeks-flip.p")

In [56]:
# Display the reloaded object's repr to inspect its steps.
clf

In [57]:
# Predict on the same frames the pipeline was fit on (training data, not a held-out set).
preds = clf.predict(frames)

In [62]:
# True means every training frame is classified correctly. This is training accuracy only
# and says little about generalisation; the commented-out KFold cross_val_score measures that.
np.all(preds == flipped)

True

In [None]:
# Alternative: train via the packaged train_classifier on the saved training set
# (presumably instead of the hand-built pipeline above — TODO confirm).
# NOTE(review): this call is repeated verbatim in a later cell; one copy can be removed.
clf = train_classifier(training_set)

In [None]:
# Scratch check of operator precedence: unary minus binds tighter than //, so this is
# (-len(frames)) // 3 == -ceil(len(frames) / 3), not -(len(frames) // 3).
-len(frames) // 3 

In [None]:
# n - floor(n / 3); differs from the slice length of x[:-len(x) // 3] (n - ceil(n / 3))
# whenever len(frames) % 3 != 0.
frames.shape[0]-len(frames) // 3 

In [None]:
# NOTE(review): identical to the earlier train_classifier cell; one of the two can be removed.
clf = train_classifier(training_set)

In [None]:
# Save `clf` (as last assigned by train_classifier above) to the same path used for the pipeline.
# NOTE(review): this overwrites that file; use a distinct name if both models should be kept.
save_classifier(clf, widget.data_path.parent / "4weeks-flip.p")

In [None]:
from pathlib import Path

In [None]:
# All per-session results files under the cohort. Sorted so `files[0]` below is deterministic:
# glob/directory-listing order depends on the filesystem.
files = sorted(Path(data_dir).glob('**/results_00.h5'))

In [None]:
# Make sure `files` is an indexable list (needed for files[0] below).
files = list(files)

In [None]:
import h5py

In [None]:
# Sanity-check one extraction: print the raw-frame stack shape and show its first frame.
if not files:
    raise FileNotFoundError(f"no results_00.h5 files found under {data_dir}")

with h5py.File(files[0], 'r') as f:
    raw_frames = f['raw_frames']  # h5py Dataset; raw_frames[0] reads only the first frame
    print(raw_frames.shape)
    fig, ax = plt.subplots()
    im = ax.imshow(raw_frames[0])
    fig.colorbar(im, ax=ax)
    ax.set_title(f"{files[0].relative_to(data_dir)}: raw_frames[0]")