In [1]:
import linecache
import csv
import pandas as pd 
import numpy as np

In [61]:
class Dataset:
    """
    Row-addressable, read-only view over a numeric csv file.

    Rows are fetched by line number through ``linecache`` and parsed one line at a time with ``csv``,
    so every record must sit on exactly one physical line (quoted fields containing newlines are not supported).
    Line 1 of the file is treated as the header; data row 0 is line 2.
    """

    def __init__(self, filename: str, rows: list=None):
        """
        Dataset method can either be initialized with a list of rows (mutable by changing the rows attribute),
        or a new list of rows may be passed in for each method that requires it, but not both (this would be ambiguous).
        
        Parameters:
        filename: Path to csv file 
        rows (optional): List of rows to initialize the dataset with
        
        Returns:
        None
        """
        # Private attributes
        self._filename = filename
        self._total_data = self._numline(filename)
        
        # Public attributes 
        self.rows = rows
        
        # Public attributes (Pandas API-like)
        # NOTE: index only mirrors the rows passed at construction; it is not refreshed when self.rows is reassigned.
        self.index = rows
        self.columns = self._get_columns()
        self.shape = (self._total_data, len(self.columns))
        
    # Python dunder methods
    
    def __getitem__(self, idx):
        """
        Returns the requested row(s) converted to floats
        
        Parameters:
        idx: int (negative values count from the end), slice (open-ended and negative bounds allowed),
             or a list / range / 1-d numpy array of ints
        
        Returns:
        ndarray: 1-d array for a single int, 2-d array (one row per requested index) otherwise
        
        Raises:
        TypeError: If idx (or one of its elements) is not an integer-like value
        IndexError: If a requested row does not exist in the file
        """
        if isinstance(idx, slice):
            # slice.indices resolves None / negative bounds and clamps to the data, like list slicing does
            return self._rows_as_array(range(*idx.indices(self._total_data)))
        if isinstance(idx, (list, range, np.ndarray)):
            return self._rows_as_array(idx)
        if isinstance(idx, (int, np.integer)):
            return np.array(self._getline(self._resolve_row(idx))).astype(float)
        raise TypeError(f"Index must be list or int, not {type(idx).__name__}")

    def __len__(self):
        return self._total_data

    def __str__(self):
        if self.rows is not None:
            return str(self[self.rows])
        return 'Dataset()'
            
    def __repr__(self):
        return self.__str__()
    
    # Private helpers
    
    def _resolve_row(self, idx):
        """
        Validates a single row number and converts it to a non-negative position
        
        Parameters:
        idx: Row number, negative values count from the last data row
        
        Returns:
        int: Position in the range [0, number of data rows)
        """
        if not isinstance(idx, (int, np.integer)):
            raise TypeError(f"Index must be list or int, not {type(idx).__name__}")
        pos = int(idx)
        if pos < 0:
            pos += self._total_data
        if not 0 <= pos < self._total_data:
            # Without this check linecache would hand back '' (or the header) and the caller would get garbage
            raise IndexError(f"Row {idx} is out of range for a dataset with {self._total_data} rows")
        return pos
    
    def _rows_as_array(self, positions):
        """
        Reads several rows and stacks them into a 2-d float array
        
        Parameters:
        positions: Iterable of row numbers
        
        Returns:
        ndarray: Array of shape (number of positions, number of columns)
        """
        records = [self._getline(self._resolve_row(i)) for i in positions]
        if not records:
            # Keep the result 2-d so that axis=0 / axis=1 reductions still work on an empty selection
            return np.empty((0, len(self.columns)), dtype=float)
        return np.array(records).astype(float)
    
    def _getline(self, idx):
        """
        Returns a line from a csv file as a list of strings (not type-checked)
        
        Parameters:
        idx: Row to return from file 
        
        Returns:
        list: Row of file with each comma-separated value as a distinct value in the list 
        """
        # +2: linecache is 1-based and line 1 is the header
        line = linecache.getline(self._filename, idx + 2)
        return next(csv.reader([line]), [])
    
    def _numline(self, filename):
        """
        Gets the number of lines in a file, should only be used for getting the total number of rows on object initialization
        
        Parameters:
        filename: Path to the file to get the number of lines from
        
        Returns:
        n: Number of lines in the file, not counting the header (never negative)
        """
        with open(filename, "r") as f:
            # Stream the file instead of materialising every line just to count them
            n = sum(1 for _ in f)
        return max(n - 1, 0)
    
    def _row_get(self, rows: list):
        """
        Returns rows from a file, either with a passed list or from the list of rows upon object initialization.
        Also performs error checking to make sure either rows were set upon initialization or passed, but not both or neither. 
        
        Parameters:
        rows: List of rows
        
        Returns:
        list: The row selection to use (the passed rows, or self.rows when none were passed)
        
        Raises:
        ValueError: If rows are available from both sources or from neither
        """
        
        if self.rows is None and rows is None:
            raise ValueError(
                f"""{self.__class__} object was not initialized with a list of rows.
                Either reinitialize with a list or rows or pass a list of rows to this method."""
            )
        if self.rows is not None and rows is not None:
            raise ValueError(
                f"""{self.__class__} object was initialized with a list of rows. Therefore, a list of rows may not be 
                passed to this method. Either reinitialize without a defined list of rows or do not pass a list into this method. """
            )
        
        # Identity test on purpose: `rows != None` is evaluated elementwise for numpy arrays
        return rows if rows is not None else self.rows

    def _get_columns(self):
        """
        Get all the columns of the csv
        
        Parameters:
        None
        
        Returns:
        list: List of column names as strings
        """
        line = linecache.getline(self._filename, 1)
        return next(csv.reader([line]), [])
    
    def _extreme(self, rows, n, axis, largest):
        """
        Shared implementation of nlargest / nsmallest
        
        Parameters:
        rows: Row selection forwarded to _row_get
        n: Number of entries to keep (values below 1 keep nothing)
        axis: Axis that is summed before ranking
        largest: True to keep the largest sums, False to keep the smallest
        
        Returns:
        Column names (list, axis == 0) or positions within the row selection (ndarray, otherwise),
        ordered from most extreme to least extreme
        """
        rows = self._row_get(rows)
        totals = np.sum(self[rows], axis=axis)
        order = np.argsort(totals)
        if largest:
            order = order[::-1]
        # Slicing from the front keeps n == 0 empty; the former `[-n:]` returned everything for n == 0
        picked = order[:max(n, 0)]
        if axis == 0:
            return [self.columns[i] for i in picked]
        return picked
    
    # Public API
    
    def sum(self, rows=None, axis=0):
        """
        Sums the given rows by the given axis
        
        Parameters:
        rows (optional): Rows to sum, only allowed when the object was created without rows
        axis: 0 sums each column over the selected rows, 1 sums each selected row over its columns
        
        Returns:
        ndarray: The sums along the given axis
        """
        rows = self._row_get(rows)
        
        return np.sum(self[rows], axis=axis)
        
    def nlargest(self, rows=None, n=20, axis=0, ascending=False):
        """
        Gets the n largest rows or columns (summed), depending on the axis 
        
        Parameters:
        rows (optional): Rows to use, only allowed when the object was created without rows
        n: Number of entries to return
        axis: 0 ranks columns (returns column names), otherwise ranks rows (returns positions within the row selection)
        ascending: False returns the largest entry first, True returns it last
        
        Returns:
        list or ndarray: The n entries with the largest sums
        """
        data = self._extreme(rows, n, axis, largest=True)
        return data[::-1] if ascending else data
    
    def nsmallest(self, rows=None, n=20, axis=0, ascending=False):
        """
        Gets the n smallest rows or columns (summed), depending on the axis 
        
        Parameters:
        rows (optional): Rows to use, only allowed when the object was created without rows
        n: Number of entries to return
        axis: 0 ranks columns (returns column names), otherwise ranks rows (returns positions within the row selection)
        ascending: True returns the smallest entry first, False returns it last (same meaning as in nlargest)
        
        Returns:
        list or ndarray: The n entries with the smallest sums
        """
        data = self._extreme(rows, n, axis, largest=False)
        return data if ascending else data[::-1]

In [33]:
# Wrap the reduced-components csv; rows 0-999 become the default selection for sum / nlargest / nsmallest.
data = Dataset('organoid_reduction_neighbors_100_components_50.csv', rows=range(0, 1000))
# Prefix every header name with 'col_' (only the in-memory column labels change, not the file).
data.columns = [f'col_{x}' for x in data.columns]

In [45]:
# Wrap the processed primary csv with rows 0-99 as the default selection.
# Because rows are preset here, the Dataset methods will reject an explicit `rows=` argument on this object.
primary = Dataset('../../organoid-classification/data/processed/primary.csv', rows=range(0, 100))

In [50]:
# Per-column totals (axis=0 is the default) over the preset rows of `data`.
test = data.sum()


array([5322.8303533, 3608.4783642, 2507.6654136, 6370.1733004,
       5712.3879036, 2609.2255756, 4878.1123719, 6309.2513283,
       5143.1252227, 5877.3215348, 3279.4795299, 4140.8169758,
       5033.1119328, 5069.9351873, 5057.1384838, 5265.4780673,
       4161.5592598, 4978.5823227, 5829.664766 , 5689.5662662,
       2999.7156943, 1717.6573407, 4449.7961611, 6852.3026499,
       4522.9429672, 4689.4657695, 4076.226207 , 4391.5872259,
       6749.661317 , 3880.5927188, 6891.3960118, 4409.8113568,
       5100.5035325, 5882.3415351, 6782.2563237, 8935.8695735,
       4905.0859755, 3610.8250737, 5059.8969808, 7091.2387445,
       5479.8825847, 2641.171667 , 5033.7170247, 4911.3634783,
       5109.7200637, 5360.0432647, 6761.0639403, 6786.0679875,
       6827.6644248, 5296.5538224])

In [63]:
# Names of the 20 columns (default n) with the largest sums over the preset rows;
# ascending=False is the default and is only spelled out for clarity.
primary.nlargest(ascending=False)

['RPL7',
 'RPS15',
 'RPL32',
 'RPS18',
 'RPS14',
 'STMN1',
 'RPL21',
 'SOX4',
 'TUBA1A',
 'ACTB',
 'RPL41',
 'RPL10',
 'TMSB10',
 'RPL39',
 'RPS19',
 'RPS27',
 'PTMA',
 'RPL34',
 'TMSB4X',
 'MALAT1']

In [86]:
# `primary` was created with preset rows, and Dataset raises ValueError when `rows=` is passed to such an
# object, so on a fresh kernel the original call fails. Use an instance without preset rows instead.
primary_all = Dataset('../../organoid-classification/data/processed/primary.csv')
# axis=1 yields one sum per selected row, so the length equals the number of requested rows.
len(primary_all.sum(rows=range(0, 10), axis=1))

10

In [88]:
# Slice access bypasses the preset rows and returns rows 0-4 as a 2-d float array.
primary[0: 5]

array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 1.26626311, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [102]:
# Small 4x3 integer matrix used to sanity-check column-wise reductions by hand.
t = np.array(
    [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 9],
        [2, 1, 0],
    ]
)

In [109]:
# Display the toy matrix.
t

array([[1, 0, 0],
       [0, 1, 0],
       [0, 0, 9],
       [2, 1, 0]])

In [112]:
# 'ij->j' sums over the row index i and keeps the column index j, i.e. per-column totals of t.
np.einsum('ij->j', t)

array([3, 2, 9])

In [117]:
# NOTE(review): `df` is not defined in any visible cell, so this fails on Restart & Run All.
# Presumably it is a pandas DataFrame loaded from the same 50-component csv as `data` — TODO confirm and add the loading cell.
df.sum(axis=0)

0     1.136211e+06
1     1.073645e+06
2     8.114607e+05
3     1.511266e+06
4     1.177429e+06
5     6.041218e+05
6     1.183298e+06
7     1.370878e+06
8     1.244941e+06
9     1.505425e+06
10    1.130370e+06
11    1.047262e+06
12    1.118208e+06
13    1.119800e+06
14    1.298040e+06
15    1.209784e+06
16    1.036354e+06
17    1.226688e+06
18    1.373266e+06
19    1.389661e+06
20    8.487344e+05
21    3.092742e+05
22    9.949686e+05
23    1.438760e+06
24    1.075432e+06
25    9.880509e+05
26    8.638106e+05
27    9.781714e+05
28    1.591229e+06
29    9.606378e+05
30    1.612723e+06
31    1.117002e+06
32    1.211222e+06
33    1.425658e+06
34    1.566990e+06
35    2.120810e+06
36    1.153098e+06
37    8.628491e+05
38    1.152845e+06
39    1.787684e+06
40    1.195477e+06
41    6.364672e+05
42    1.185073e+06
43    1.132126e+06
44    1.309915e+06
45    1.246404e+06
46    1.606879e+06
47    1.606189e+06
48    1.632589e+06
49    1.250043e+06
dtype: float64