# Databases and SQL

In [2]:
from typing import Tuple, Sequence, List, Any, Callable, Dict, Iterator

from collections import defaultdict

In [47]:
Row = Dict[str, Any]                          # one record: column name -> value
WhereClause = Callable[[Row], bool]           # per-row filter (SQL WHERE)
HavingClause = Callable[[List[Row]], bool]    # per-group filter (SQL HAVING)

In [48]:
class Table:
    """A minimal in-memory relational table supporting SQL-like operations.

    Rows are stored as dicts mapping column name to value. Every value written
    through ``insert`` or ``update`` is checked against the declared type of
    its column; ``None`` is accepted for any column (it stands in for NULL).

    Query-style methods (``select``, ``where``, ``limit``, ``group_by``,
    ``order_by``, ``join``) leave this table unchanged and return a new one.

    The ``Row`` / ``WhereClause`` / ``HavingClause`` aliases appear as quoted
    forward references in signatures, so executing this class body does not
    require those module-level aliases to be defined first.
    """

    def __init__(self, columns: List[str], types: List[type]) -> None:
        """Create an empty table.

        Args:
            columns: Column names, in order.
            types: Expected Python type of each column, aligned with ``columns``.

        Raises:
            AssertionError: If ``columns`` and ``types`` differ in length.
        """
        assert len(columns) == len(types)

        # Copy so later mutation of the caller's lists (or of a sibling table
        # built from the same lists) cannot silently change this schema.
        self.columns = list(columns)
        self.types = list(types)
        self.rows: List[Row] = []

    @staticmethod
    def _return_type(fn: Callable) -> type:
        """Return ``fn``'s annotated return type, or ``object`` if it has none.

        The fallback lets unannotated callables (e.g. lambdas) be used as
        computed columns / aggregates; their results then pass any type check.
        """
        return getattr(fn, '__annotations__', {}).get('return', object)

    def col2type(self, col: str) -> type:
        """Return the declared type of column ``col``.

        Raises:
            ValueError: If ``col`` is not a column of this table.
        """
        return self.types[self.columns.index(col)]

    def insert(self, values: list) -> None:
        """Append one row given its values in column order.

        Raises:
            ValueError: If the number of values does not match the columns.
            TypeError: If a non-None value does not match its column's type.
        """
        if len(values) != len(self.types):
            raise ValueError(f"You need to provide {len(self.types)} values")

        for value, typ3 in zip(values, self.types):
            if value is not None and not isinstance(value, typ3):
                raise TypeError("Unexpected type")

        self.rows.append(dict(zip(self.columns, values)))

    def update(self, updates: Dict[str, Any],
               predicate: 'WhereClause' = lambda row: True) -> None:
        """Set ``updates`` (column -> new value) on every row matching ``predicate``.

        All columns and values are validated before any row is touched, so a
        failed update leaves the table unchanged.

        Raises:
            ValueError: If a key of ``updates`` is not a column.
            TypeError: If a non-None new value does not match its column's type.
        """
        for column, new_value in updates.items():
            if column not in self.columns:
                raise ValueError(f"Invalid column: {column}")
            typ3 = self.col2type(column)
            if new_value is not None and not isinstance(new_value, typ3):
                raise TypeError("Unexpected type")

        # Validation passed; apply every update exactly once per matching row.
        for row in self.rows:
            if predicate(row):
                row.update(updates)

    def delete(self, predicate: 'WhereClause' = lambda row: True) -> None:
        """Remove every row matching ``predicate`` (all rows by default)."""
        self.rows = [row for row in self.rows if not predicate(row)]

    def select(self, keep_columns: List[str] = None,
               additional_columns: Dict[str, Callable] = None) -> 'Table':
        """Return a new table with a subset of columns plus computed columns.

        Args:
            keep_columns: Columns to copy over; defaults to all columns.
            additional_columns: New column name -> function computing its value
                from a row. The column's type is the function's return
                annotation, or ``object`` if it has none.
        """
        if keep_columns is None:
            keep_columns = self.columns
        if additional_columns is None:
            additional_columns = {}

        calculations = list(additional_columns.values())
        new_columns = list(keep_columns) + list(additional_columns.keys())
        keep_types = [self.col2type(col) for col in keep_columns]
        add_types = [self._return_type(calc) for calc in calculations]

        new_table = Table(new_columns, keep_types + add_types)

        for row in self.rows:
            new_row = [row[column] for column in keep_columns]
            new_row.extend(calc(row) for calc in calculations)
            # One insert per source row, after every computed value is added.
            new_table.insert(new_row)
        return new_table

    def where(self, predicate: 'WhereClause' = lambda row: True) -> 'Table':
        """Return only the rows that satisfy the supplied predicate"""
        where_table = Table(self.columns, self.types)
        for row in self.rows:
            if predicate(row):
                where_table.insert([row[column] for column in self.columns])
        return where_table

    def limit(self, num_rows: int) -> 'Table':
        """Return only the first `num_rows` rows (none if `num_rows` <= 0)"""
        limit_table = Table(self.columns, self.types)
        # max() keeps negative counts meaning "no rows" rather than slicing
        # from the end.
        for row in self.rows[:max(num_rows, 0)]:
            limit_table.insert([row[column] for column in self.columns])
        return limit_table

    def group_by(self,
                 group_by_columns: List[str],
                 aggregates: Dict[str, Callable],
                 having: 'HavingClause' = lambda group: True) -> 'Table':
        """Group rows by ``group_by_columns`` and compute ``aggregates`` per group.

        Args:
            group_by_columns: Columns whose value tuple defines a group.
            aggregates: Output column name -> function of the group's row list.
                The column's type is the function's return annotation, or
                ``object`` if it has none.
            having: Keeps only groups for which it returns True.

        Returns:
            A table with the group-by columns followed by the aggregate columns.
        """
        grouped_rows = defaultdict(list)

        # Populate groups
        for row in self.rows:
            key = tuple(row[column] for column in group_by_columns)
            grouped_rows[key].append(row)

        # Result table consists of group_by columns and aggregates
        new_columns = group_by_columns + list(aggregates.keys())
        group_by_types = [self.col2type(col) for col in group_by_columns]
        aggregate_types = [self._return_type(agg)
                           for agg in aggregates.values()]
        result_table = Table(new_columns, group_by_types + aggregate_types)

        for key, rows in grouped_rows.items():
            if having(rows):
                new_row = list(key)
                new_row.extend(aggregate_fn(rows)
                               for aggregate_fn in aggregates.values())
                result_table.insert(new_row)

        return result_table

    def order_by(self, order: 'Callable[[Row], Any]') -> 'Table':
        """Return a copy of this table with rows sorted by the key ``order``."""
        new_table = self.select()       # full copy of this table
        new_table.rows.sort(key=order)
        return new_table

    def join(self, other_table: 'Table', left_join: bool = False) -> 'Table':
        """Join with ``other_table`` on all columns the two tables share.

        Args:
            other_table: Right-hand table.
            left_join: If True, left rows with no match are kept, with ``None``
                in the right-only columns; otherwise they are dropped.

        Returns:
            A table with all left columns followed by the right-only columns.
        """
        join_on_columns = [c for c in self.columns           # columns in
                           if c in other_table.columns]      # both tables

        additional_columns = [c for c in other_table.columns # columns only
                              if c not in join_on_columns]   # in right table

        # all columns from left table + additional_columns from right table
        new_columns = self.columns + additional_columns
        new_types = self.types + [other_table.col2type(col)
                                  for col in additional_columns]

        join_table = Table(new_columns, new_types)

        for row in self.rows:
            def is_join(other_row):
                return all(other_row[c] == row[c] for c in join_on_columns)

            # is_join is consumed immediately, so binding `row` late is safe.
            other_rows = other_table.where(is_join).rows

            # Each other row that matches this one produces a result row.
            for other_row in other_rows:
                join_table.insert([row[c] for c in self.columns] +
                                  [other_row[c] for c in additional_columns])

            # If no rows match and it's a left join, output with Nones.
            if left_join and not other_rows:
                join_table.insert([row[c] for c in self.columns] +
                                  [None] * len(additional_columns))

        return join_table

    def __getitem__(self, idx: int) -> 'Row':
        """Return the row dict at position ``idx``."""
        return self.rows[idx]

    def __iter__(self) -> 'Iterator[Row]':
        """Iterate over the row dicts in insertion/sort order."""
        return iter(self.rows)

    def __len__(self) -> int:
        """Return the number of rows."""
        return len(self.rows)

    def __repr__(self) -> str:
        """Show the column list followed by one row dict per line."""
        rows = "\n".join(str(row) for row in self.rows)
        return f"{self.columns}\n{rows}"


In [34]:
# Build an empty users table: integer id, string name, integer friend count.
users = Table(columns=['user_id', 'name', 'num_friends'], types=[int, str, int])

In [35]:
# Seed four users, in order, as [user_id, name, num_friends].
for record in ([0, 'hero', 0], [1, 'dunn', 2], [2, 'Lal', 3], [3, 'niyas', 4]):
    users.insert(record)

In [36]:
print(str(users))  # header row of column names, then one dict per row

['user_id', 'name', 'num_friends']
{'user_id': 0, 'name': 'hero', 'num_friends': 0}
{'user_id': 1, 'name': 'dunn', 'num_friends': 2}
{'user_id': 2, 'name': 'Lal', 'num_friends': 3}
{'user_id': 3, 'name': 'niyas', 'num_friends': 4}


In [37]:
# Show user 1's friend count, then bump it to 3.
print(users.rows[1]['num_friends'])
users.update(updates={'num_friends': 3}, predicate=lambda r: r['user_id'] == 1)

2


In [38]:
print(users.rows[1]['num_friends'])  # confirm the update took effect

3


In [39]:
users.delete(predicate=lambda r: r['user_id'] == 1)  # drop only user 1
users.delete()  # no predicate: every remaining row is removed

In [40]:
print(str(users))  # only the column header remains after deleting all rows

['user_id', 'name', 'num_friends']

