In [None]:
import numpy as np

### Class definitions to make things easier

In [None]:
# Maps every occupation label accepted by User to its position in
# User.occupationVector.
userOccupations = {
        "administrator" : 0,
        "artist" : 1,
        "doctor" : 2,
        "educator" : 3,
        "engineer" : 4,
        "entertainment" : 5,
        "executive" : 6,
        "healthcare" : 7,
        "homemaker" : 8,
        "lawyer" : 9,
        "librarian" : 10,
        "marketing" : 11,
        "none" : 12,
        "other" : 13,
        "programmer" : 14,
        "retired" : 15,
        "salesman" : 16,
        "scientist" : 17,
        "student" : 18,
        "technician" : 19,
        "writer" : 20
    }

class User(object):

    """
        Metadata of a single user, encoded for feature generation.

        Attributes

        * userId           - stored exactly as passed in
        * age              - stored exactly as passed in
        * gender           - 0 for 'M', 1 for 'F', None for anything else
        * occupationVector - list of 0/1 values with one slot per key of
                             userOccupations; only the slot of the user's
                             occupation is 1
        * zipCode          - stored exactly as passed in

        Raises KeyError when occupation is not a key of userOccupations.
    """

    def __init__(self, userId, age, gender, occupation, zipCode):

        self.userId = userId
        self.age = age

        # Gender mapping: M -> 0, F -> 1.
        # Any other value used to leave the attribute undefined, which only
        # showed up later as an AttributeError far from the cause; it is now
        # always defined and set to None for unrecognised input.
        if gender == 'M':
            self.gender = 0
        elif gender == 'F':
            self.gender = 1
        else:
            self.gender = None

        # One-hot occupation vector. Its length follows the mapping so the
        # two cannot drift apart (21 entries with the table above).
        self.occupationVector = [0] * len(userOccupations)
        self.occupationVector[userOccupations[occupation]] = 1

        self.zipCode = zipCode


class Movie(object):

    """
        Plain record describing one movie.

        * itemId      - identifier, stored as given
        * year        - release year, stored as given
        * genreVector - the genre argument, stored as given (no copy is made)
    """

    def __init__(self, itemId, year, genre):

        self.itemId, self.year = itemId, year
        self.genreVector = genre


In [None]:
def getAllUserMetadata(u_user_path):

    """
        Parses the '|'-separated user file at u_user_path and returns a
        list of User objects, one per non-blank line, in file order.

        Expected fields per line: id|age|gender|occupation|zip code

        1 is subtracted from userid,
        to make it 0 - indexed

        Raises ValueError if the id or the age is not an integer,
        IndexError if a line has fewer than five fields and KeyError
        (from User) for an unknown occupation.
    """

    allUsers = []

    with open(u_user_path) as fp:

        for line in fp:

            # Remove only the line terminator. Blindly dropping the last
            # character (as was done before) eats a digit of the zip code
            # when the final line has no trailing newline.
            record = line.rstrip('\r\n')

            # Blank lines carry no record
            if not record.strip():
                continue

            values = record.split('|')

            userId = int(values[0]) - 1
            userAge = int(values[1])
            userGender = values[2]
            userOccupation = values[3]
            userZipCode = values[4]

            allUsers.append(User(userId, userAge, userGender, userOccupation, userZipCode))

    return allUsers

def getAllMovieMetadata(u_item_path):

    """
        Parses the '|'-separated movie file at u_item_path and returns a
        list of Movie objects, one per non-blank line, in file order.

        * field 0          - movie id
        * field 2          - release date such as 01-Jan-1994; only the part
                             after the last '-' is used, as the year
        * fields 5 onwards - genre flags, kept as the strings read from
                             the file

        1 is subtracted from movieid,
        to make it 0 - indexed

        A missing or unparsable release date gives year = 0.
    """

    allMovies = []

    with open(u_item_path) as fp:

        for line in fp:

            # Remove only the line terminator. Dropping the last character
            # unconditionally loses the final genre flag when the last
            # line has no trailing newline.
            record = line.rstrip('\r\n')

            # Blank lines carry no record
            if not record.strip():
                continue

            values = record.split('|')

            itemId = int(values[0]) - 1

            try:
                year = int(values[2].split("-")[-1])    # 01-Jan-1994
            except (ValueError, IndexError):
                # Empty / malformed date or missing field: 0 marks
                # "year unknown". Other errors are no longer swallowed.
                year = 0

            genreVector = values[5:]

            allMovies.append(Movie(itemId, year, genreVector))

    return allMovies


def getAllRatings(u_data_path):

    """
        Parses the whitespace-separated ratings file at u_data_path.

        Only the first three fields of a line are used:
        user id, movie id, rating. Anything after them is ignored.

        1 is subtracted from movieid, userid
        to make it 0 - indexed

        Returns a plain list of [userId, movieId, rating] lists
        (not a numpy array), in file order.

        Raises ValueError for non-integer fields and IndexError for
        a non-blank line with fewer than three fields.
    """

    allRatings = []

    with open(u_data_path) as fp:

        for line in fp:

            # split() already discards the newline and surrounding
            # whitespace, so no character has to be chopped off by hand
            # (doing so broke a last line without a trailing newline).
            values = line.split()

            # Blank lines carry no record
            if not values:
                continue

            userId = int(values[0]) - 1
            movieId = int(values[1]) - 1
            rating = int(values[2])

            allRatings.append([userId, movieId, rating])

    return allRatings


### Code to generate folds

In [None]:
def getDatasetInfo(u_info_path):

    """
        Reads the dataset summary file at u_info_path, where every
        non-blank line has the form "<count> <name>".

        returns dict {

            <name> : <count as int>,
            ...

        }

        The keys are exactly the names found in the file; generateFolds
        relies on a "users" entry being present.

        Raises ValueError if a count is not an integer and IndexError if
        a non-blank line has no name.
    """

    info = {}

    with open(u_info_path) as fp:

        for line in fp:

            # split() drops the newline by itself. Removing the last
            # character by hand turned the final key into e.g. "rating"
            # whenever the file did not end with a newline.
            values = line.split()

            # Blank lines carry no record
            if not values:
                continue

            info[values[1]] = int(values[0])

    return info


import math

def generateFolds(ratings, k = 5, u_info_path = "../ml-100k/u.info"):

    """
        Segments the dataset into k folds, depending
        on the number of users. For ex, for k = 5,
        generates 5 folds each having 20 % of the users

        Users are grouped by contiguous id range: with
        s = ceil(totalUsers / k), a rating of user u goes
        to fold int(u / s).

        * ratings     - iterable of records whose first element is the
                        0-indexed user id (as produced by getAllRatings)
        * k           - number of folds
        * u_info_path - dataset summary file providing the "users" count;
                        defaults to the path that used to be hard-coded

        Returns a list of k lists of rating records.

        Raises KeyError if the summary has no "users" entry and
        IndexError if a user id is not below the user count.
    """

    info = getDatasetInfo(u_info_path)

    totalUsers = info["users"]
    usersInSegment = math.ceil(float(totalUsers)/k)

    segments = [[] for _ in range(k)]

    for point in ratings:

        userId = point[0]

        # Named foldIndex rather than bin to avoid shadowing the builtin
        foldIndex = int(userId/usersInSegment)

        segments[foldIndex].append(point)

    return segments


### Code to generate feature vectors

In [None]:
def generateFeatureVector(currUser, currMovie):

    """
        Builds the feature vector of one (user, movie) pair.

        feature vector, in this order

        * item genre vector (dim = 19, i.e. currMovie.genreVector as is)
        * item release year bin (dim = 12, 10 year bins 1900 - 2019)
        * user occupation vector (dim = 21)
        * user gender (dim = 1)
        * user age bin (dim = 11, 10 year bins for ages 0 - 109)

        total dim = 64, provided genreVector has 19 entries and
        occupationVector has 21

        A year outside 1900 - 2019 (including the 0 used for "unknown")
        or an age outside 0 - 109 leaves the corresponding bins all zero
        instead of raising IndexError or wrapping around to another bin.

        Returns a 1-D numpy array of dtype float64.
    """

    # item based features

    yearBin = [0] * 12

    # 1900 itself belongs to the first bin, hence the inclusive lower bound.
    # '//' keeps the index an integer on Python 3 as well as Python 2.
    if 1900 <= currMovie.year < 2020:
        yearBin[(currMovie.year - 1900) // 10] = 1

    # user based features

    ageBin = [0] * 11

    if 0 <= currUser.age < 110:
        ageBin[currUser.age // 10] = 1

    features = []

    features += currMovie.genreVector
    features += yearBin

    features += currUser.occupationVector
    features += [currUser.gender]
    features += ageBin

    # numpy performs the conversion of every entry to float64
    return np.array(features, dtype = 'float64')
    

In [None]:
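# Precomputed folds: user_train[j] lists the user ids used for training in fold j.
# user_test (built right after the list) holds the remaining ids from 1 to 943.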
user_train = [
    [155, 108, 630, 243, 472, 47, 680, 869, 452, 245, 761, 861, 11, 575, 22, 890, 548, 160, 358, 137, 419, 938, 934, 750, 716, 686, 525, 318, 644, 60, 436, 72, 335, 867, 116, 729, 605, 118, 404, 809, 139, 738, 485, 354, 648, 434, 334, 426, 591, 345, 258, 701, 917, 647, 381, 315, 129, 320, 88, 806, 746, 413, 405, 849, 891, 57, 926, 109, 316, 580, 34, 435, 687, 518, 55, 664, 356, 699, 368, 816, 82, 524, 731, 916, 69, 86, 375, 817, 603, 327, 250, 918, 89, 212, 185, 456, 527, 31, 178, 756, 864, 238, 778, 51, 587, 941, 218, 165, 885, 851, 852, 149, 921, 126, 656, 519, 236, 4, 420, 104, 208, 805, 905, 187, 713, 510, 438, 562, 720, 824, 367, 717, 622, 592, 415, 544, 133, 317, 240, 374, 787, 417, 300, 844, 180, 735, 339, 627, 346, 793, 303, 894, 486, 623, 924, 666, 840, 388, 5, 64, 937, 425, 639, 39, 393, 296, 221, 24, 607, 507, 528, 3, 122, 758, 903, 765, 802, 35, 792, 18, 78, 594, 27, 617, 643, 879, 440, 652, 674, 808, 396, 598, 76, 7, 283, 912, 471, 551, 432, 91, 403, 257, 786, 313, 675, 84, 437, 721, 25, 704, 17, 825, 341, 706, 307, 505, 860, 788, 366, 301, 431, 384, 231, 252, 555, 255, 578, 211, 818, 759, 523, 249, 230, 71, 906, 601, 370, 718, 351, 541, 400, 297, 640, 641, 768, 111, 841, 631, 780, 833, 399, 325, 676, 877, 939, 462, 157, 621, 803, 153, 44, 74, 406, 693, 857, 458, 256, 56, 189, 506, 789, 328, 330, 290, 757, 822, 751, 58, 128, 858, 176, 340, 762, 907, 913, 459, 372, 410, 613, 285, 269, 660, 714, 931, 379, 625, 567, 12, 205, 387, 821, 439, 421, 175, 828, 897, 588, 874, 63, 408, 476, 794, 629, 115, 901, 248, 350, 769, 511, 726, 244, 929, 743, 422, 742, 584, 933, 512, 164, 744, 943, 888, 19, 397, 870, 691, 559, 694, 597, 733, 722, 352, 586, 492, 708, 854, 201, 842, 487, 349, 542, 654, 820, 583, 150, 709, 110, 651, 232, 563, 222, 935, 710, 323, 585, 514, 683, 475, 797, 54, 538, 760, 306, 196, 740, 120, 254, 347, 329, 41, 754, 736, 37, 28, 517, 247, 490, 29, 391, 807, 715, 619, 166, 468, 895, 695, 93, 446, 259, 390, 624, 810, 210, 198, 81, 739, 611, 681, 804, 
235, 923, 161, 16, 188, 940, 634, 522, 655, 859, 170, 608, 574, 134, 105, 186, 669, 534, 930, 836, 229, 441, 663, 556, 190, 424, 568, 557, 73, 125, 653, 270, 80, 637, 826, 596, 814, 204, 62, 838, 284, 147, 745, 355, 638, 561, 145, 203, 407, 455, 326, 868, 628, 549, 226, 42, 464, 234, 942, 107, 357, 495, 581, 657, 679, 478, 360, 171, 48, 470, 292, 457, 474, 378, 577, 276, 764, 373, 711, 101, 138, 215, 783, 2, 273, 773, 499, 753, 609, 172, 163, 480, 798, 113, 242, 494, 831, 392, 333, 451, 920, 689, 168, 719, 922, 265, 565, 167, 866, 785, 819, 262, 401, 479, 267, 322, 402, 195, 263, 896, 233, 179, 767, 338, 331, 130, 582, 214, 650, 484, 428, 59, 579, 227, 43, 181, 70, 772, 174, 460, 902, 509, 8, 336, 9, 846, 92, 658, 537, 217, 602, 282, 862, 26, 353, 65, 530, 38, 385, 304, 266, 904, 776, 447, 604, 23, 747, 572, 97, 889, 127, 914, 908, 752, 633, 199, 246, 615, 610, 732, 409, 314, 491, 837, 272, 883, 239, 501, 553, 295, 124, 182, 66, 298, 337, 461, 812, 482, 661, 427, 13, 667, 498, 600, 33, 659, 483, 554, 829, 312, 466, 535, 260, 727, 264, 673, 823, 95, 429, 450, 219, 892, 884, 209, 237, 103, 723, 10, 893, 49, 790, 540, 433, 725, 289, 845, 690, 102, 52, 678, 672, 636, 184, 531, 83, 389, 364, 77, 200, 682, 193, 223, 173, 194, 775, 684, 552, 599, 67, 873, 569, 152, 288, 216, 131, 376, 872, 632, 595, 618, 671, 344, 855, 87, 665, 515, 774, 183, 763, 363, 411, 469, 620, 140, 547, 881, 500, 117, 251, 207, 121, 766, 206, 444, 589, 909, 30, 730, 835, 520, 141, 114, 606, 688, 359, 779, 550, 280, 593, 146, 626, 856, 148, 310, 800, 504, 332, 536, 685, 348, 362, 900, 662, 386, 177, 936, 670, 496, 724, 848, 832, 532, 463, 692, 781, 85, 308, 61, 197, 271, 612, 220, 635, 275, 801],
    [898, 343, 213, 886, 281, 94, 377, 645, 925, 649, 697, 40, 454, 6, 677, 382, 1, 15, 521, 576, 915, 755, 863, 503, 497, 533, 899, 566, 777, 278, 813, 871, 477, 573, 782, 32, 481, 700, 36, 488, 132, 90, 191, 728, 154, 830, 910, 369, 380, 321, 834, 342, 291, 414, 293, 516, 927, 68, 539, 876, 99, 703, 795, 100, 430, 443, 784, 169, 526, 847, 423, 734, 564, 287, 796, 502, 875, 158, 642, 545, 112, 570, 324, 311, 394, 448, 261, 646, 136, 144, 811, 309, 748, 241, 799, 707, 53, 489, 712, 398, 546, 590, 50, 383, 815, 558, 839, 749, 96, 878, 698, 887, 702, 853, 123, 279, 928, 268, 616, 294, 843, 45, 543, 737, 156, 412, 299, 529, 20, 224, 919, 493, 880, 771, 192, 445, 791, 98, 159, 253, 202, 453, 365, 865, 106, 827, 135, 571, 418, 14, 225, 696, 119, 467, 741, 143, 705, 508, 302, 473, 75, 286, 305, 770, 274, 371, 228, 560, 46, 882, 395, 361, 932, 142, 911, 79, 319, 416, 465, 449, 850, 21, 162, 151, 614, 513, 277, 668, 442, 808, 396, 598, 76, 7, 283, 912, 471, 551, 432, 91, 403, 257, 786, 313, 675, 84, 437, 721, 25, 704, 17, 825, 341, 706, 307, 505, 860, 788, 366, 301, 431, 384, 231, 252, 555, 255, 578, 211, 818, 759, 523, 249, 230, 71, 906, 601, 370, 718, 351, 541, 400, 297, 640, 641, 768, 111, 841, 631, 780, 833, 399, 325, 676, 877, 939, 462, 157, 621, 803, 153, 44, 74, 406, 693, 857, 458, 256, 56, 189, 506, 789, 328, 330, 290, 757, 822, 751, 58, 128, 858, 176, 340, 762, 907, 913, 459, 372, 410, 613, 285, 269, 660, 714, 931, 379, 625, 567, 12, 205, 387, 821, 439, 421, 175, 828, 897, 588, 874, 63, 408, 476, 794, 629, 115, 901, 248, 350, 769, 511, 726, 244, 929, 743, 422, 742, 584, 933, 512, 164, 744, 943, 888, 19, 397, 870, 691, 559, 694, 597, 733, 722, 352, 586, 492, 708, 854, 201, 842, 487, 349, 542, 654, 820, 583, 150, 709, 110, 651, 232, 563, 222, 935, 710, 323, 585, 514, 683, 475, 797, 54, 538, 760, 306, 196, 740, 120, 254, 347, 329, 41, 754, 736, 37, 28, 517, 247, 490, 29, 391, 807, 715, 619, 166, 468, 895, 695, 93, 446, 259, 390, 624, 810, 210, 198, 81, 739, 611, 681, 
804, 235, 923, 161, 16, 188, 940, 634, 522, 655, 859, 170, 608, 574, 134, 105, 186, 669, 534, 930, 836, 229, 441, 663, 556, 190, 424, 568, 557, 73, 125, 653, 270, 80, 637, 826, 596, 814, 204, 62, 838, 284, 147, 745, 355, 638, 561, 145, 203, 407, 455, 326, 868, 628, 549, 226, 42, 464, 234, 942, 107, 357, 495, 581, 657, 679, 478, 360, 171, 48, 470, 292, 457, 474, 378, 577, 276, 764, 373, 711, 101, 138, 215, 783, 2, 273, 773, 499, 753, 609, 172, 163, 480, 798, 113, 242, 494, 831, 392, 333, 451, 920, 689, 168, 719, 922, 265, 565, 167, 866, 785, 819, 262, 401, 479, 267, 322, 402, 195, 263, 896, 233, 179, 767, 338, 331, 130, 582, 214, 650, 484, 428, 59, 579, 227, 43, 181, 70, 772, 174, 460, 902, 509, 8, 336, 9, 846, 92, 658, 537, 217, 602, 282, 862, 26, 353, 65, 530, 38, 385, 304, 266, 904, 776, 447, 604, 23, 747, 572, 97, 889, 127, 914, 908, 752, 633, 199, 246, 615, 610, 732, 409, 314, 491, 837, 272, 883, 239, 501, 553, 295, 124, 182, 66, 298, 337, 461, 812, 482, 661, 427, 13, 667, 498, 600, 33, 659, 483, 554, 829, 312, 466, 535, 260, 727, 264, 673, 823, 95, 429, 450, 219, 892, 884, 209, 237, 103, 723, 10, 893, 49, 790, 540, 433, 725, 289, 845, 690, 102, 52, 678, 672, 636, 184, 531, 83, 389, 364, 77, 200, 682, 193, 223, 173, 194, 775, 684, 552, 599, 67, 873, 569, 152, 288, 216, 131, 376, 872, 632, 595, 618, 671, 344, 855, 87, 665, 515, 774, 183, 763, 363, 411, 469, 620, 140, 547, 881, 500, 117, 251, 207, 121, 766, 206, 444, 589, 909, 30, 730, 835, 520, 141, 114, 606, 688, 359, 779, 550, 280, 593, 146, 626, 856, 148, 310, 800, 504, 332, 536, 685, 348, 362, 900, 662, 386, 177, 936, 670, 496, 724, 848, 832, 532, 463, 692, 781, 85, 308, 61, 197, 271, 612, 220, 635, 275, 801],
    [898, 343, 213, 886, 281, 94, 377, 645, 925, 649, 697, 40, 454, 6, 677, 382, 1, 15, 521, 576, 915, 755, 863, 503, 497, 533, 899, 566, 777, 278, 813, 871, 477, 573, 782, 32, 481, 700, 36, 488, 132, 90, 191, 728, 154, 830, 910, 369, 380, 321, 834, 342, 291, 414, 293, 516, 927, 68, 539, 876, 99, 703, 795, 100, 430, 443, 784, 169, 526, 847, 423, 734, 564, 287, 796, 502, 875, 158, 642, 545, 112, 570, 324, 311, 394, 448, 261, 646, 136, 144, 811, 309, 748, 241, 799, 707, 53, 489, 712, 398, 546, 590, 50, 383, 815, 558, 839, 749, 96, 878, 698, 887, 702, 853, 123, 279, 928, 268, 616, 294, 843, 45, 543, 737, 156, 412, 299, 529, 20, 224, 919, 493, 880, 771, 192, 445, 791, 98, 159, 253, 202, 453, 365, 865, 106, 827, 135, 571, 418, 14, 225, 696, 119, 467, 741, 143, 705, 508, 302, 473, 75, 286, 305, 770, 274, 371, 228, 560, 46, 882, 395, 361, 932, 142, 911, 79, 319, 416, 465, 449, 850, 21, 162, 151, 614, 513, 277, 668, 442, 155, 108, 630, 243, 472, 47, 680, 869, 452, 245, 761, 861, 11, 575, 22, 890, 548, 160, 358, 137, 419, 938, 934, 750, 716, 686, 525, 318, 644, 60, 436, 72, 335, 867, 116, 729, 605, 118, 404, 809, 139, 738, 485, 354, 648, 434, 334, 426, 591, 345, 258, 701, 917, 647, 381, 315, 129, 320, 88, 806, 746, 413, 405, 849, 891, 57, 926, 109, 316, 580, 34, 435, 687, 518, 55, 664, 356, 699, 368, 816, 82, 524, 731, 916, 69, 86, 375, 817, 603, 327, 250, 918, 89, 212, 185, 456, 527, 31, 178, 756, 864, 238, 778, 51, 587, 941, 218, 165, 885, 851, 852, 149, 921, 126, 656, 519, 236, 4, 420, 104, 208, 805, 905, 187, 713, 510, 438, 562, 720, 824, 367, 717, 622, 592, 415, 544, 133, 317, 240, 374, 787, 417, 300, 844, 180, 735, 339, 627, 346, 793, 303, 894, 486, 623, 924, 666, 840, 388, 5, 64, 937, 425, 639, 39, 393, 296, 221, 24, 607, 507, 528, 3, 122, 758, 903, 765, 802, 35, 792, 18, 78, 594, 27, 617, 643, 879, 440, 652, 674, 347, 329, 41, 754, 736, 37, 28, 517, 247, 490, 29, 391, 807, 715, 619, 166, 468, 895, 695, 93, 446, 259, 390, 624, 810, 210, 198, 81, 739, 611, 681, 804, 
235, 923, 161, 16, 188, 940, 634, 522, 655, 859, 170, 608, 574, 134, 105, 186, 669, 534, 930, 836, 229, 441, 663, 556, 190, 424, 568, 557, 73, 125, 653, 270, 80, 637, 826, 596, 814, 204, 62, 838, 284, 147, 745, 355, 638, 561, 145, 203, 407, 455, 326, 868, 628, 549, 226, 42, 464, 234, 942, 107, 357, 495, 581, 657, 679, 478, 360, 171, 48, 470, 292, 457, 474, 378, 577, 276, 764, 373, 711, 101, 138, 215, 783, 2, 273, 773, 499, 753, 609, 172, 163, 480, 798, 113, 242, 494, 831, 392, 333, 451, 920, 689, 168, 719, 922, 265, 565, 167, 866, 785, 819, 262, 401, 479, 267, 322, 402, 195, 263, 896, 233, 179, 767, 338, 331, 130, 582, 214, 650, 484, 428, 59, 579, 227, 43, 181, 70, 772, 174, 460, 902, 509, 8, 336, 9, 846, 92, 658, 537, 217, 602, 282, 862, 26, 353, 65, 530, 38, 385, 304, 266, 904, 776, 447, 604, 23, 747, 572, 97, 889, 127, 914, 908, 752, 633, 199, 246, 615, 610, 732, 409, 314, 491, 837, 272, 883, 239, 501, 553, 295, 124, 182, 66, 298, 337, 461, 812, 482, 661, 427, 13, 667, 498, 600, 33, 659, 483, 554, 829, 312, 466, 535, 260, 727, 264, 673, 823, 95, 429, 450, 219, 892, 884, 209, 237, 103, 723, 10, 893, 49, 790, 540, 433, 725, 289, 845, 690, 102, 52, 678, 672, 636, 184, 531, 83, 389, 364, 77, 200, 682, 193, 223, 173, 194, 775, 684, 552, 599, 67, 873, 569, 152, 288, 216, 131, 376, 872, 632, 595, 618, 671, 344, 855, 87, 665, 515, 774, 183, 763, 363, 411, 469, 620, 140, 547, 881, 500, 117, 251, 207, 121, 766, 206, 444, 589, 909, 30, 730, 835, 520, 141, 114, 606, 688, 359, 779, 550, 280, 593, 146, 626, 856, 148, 310, 800, 504, 332, 536, 685, 348, 362, 900, 662, 386, 177, 936, 670, 496, 724, 848, 832, 532, 463, 692, 781, 85, 308, 61, 197, 271, 612, 220, 635, 275, 801],
    [898, 343, 213, 886, 281, 94, 377, 645, 925, 649, 697, 40, 454, 6, 677, 382, 1, 15, 521, 576, 915, 755, 863, 503, 497, 533, 899, 566, 777, 278, 813, 871, 477, 573, 782, 32, 481, 700, 36, 488, 132, 90, 191, 728, 154, 830, 910, 369, 380, 321, 834, 342, 291, 414, 293, 516, 927, 68, 539, 876, 99, 703, 795, 100, 430, 443, 784, 169, 526, 847, 423, 734, 564, 287, 796, 502, 875, 158, 642, 545, 112, 570, 324, 311, 394, 448, 261, 646, 136, 144, 811, 309, 748, 241, 799, 707, 53, 489, 712, 398, 546, 590, 50, 383, 815, 558, 839, 749, 96, 878, 698, 887, 702, 853, 123, 279, 928, 268, 616, 294, 843, 45, 543, 737, 156, 412, 299, 529, 20, 224, 919, 493, 880, 771, 192, 445, 791, 98, 159, 253, 202, 453, 365, 865, 106, 827, 135, 571, 418, 14, 225, 696, 119, 467, 741, 143, 705, 508, 302, 473, 75, 286, 305, 770, 274, 371, 228, 560, 46, 882, 395, 361, 932, 142, 911, 79, 319, 416, 465, 449, 850, 21, 162, 151, 614, 513, 277, 668, 442, 155, 108, 630, 243, 472, 47, 680, 869, 452, 245, 761, 861, 11, 575, 22, 890, 548, 160, 358, 137, 419, 938, 934, 750, 716, 686, 525, 318, 644, 60, 436, 72, 335, 867, 116, 729, 605, 118, 404, 809, 139, 738, 485, 354, 648, 434, 334, 426, 591, 345, 258, 701, 917, 647, 381, 315, 129, 320, 88, 806, 746, 413, 405, 849, 891, 57, 926, 109, 316, 580, 34, 435, 687, 518, 55, 664, 356, 699, 368, 816, 82, 524, 731, 916, 69, 86, 375, 817, 603, 327, 250, 918, 89, 212, 185, 456, 527, 31, 178, 756, 864, 238, 778, 51, 587, 941, 218, 165, 885, 851, 852, 149, 921, 126, 656, 519, 236, 4, 420, 104, 208, 805, 905, 187, 713, 510, 438, 562, 720, 824, 367, 717, 622, 592, 415, 544, 133, 317, 240, 374, 787, 417, 300, 844, 180, 735, 339, 627, 346, 793, 303, 894, 486, 623, 924, 666, 840, 388, 5, 64, 937, 425, 639, 39, 393, 296, 221, 24, 607, 507, 528, 3, 122, 758, 903, 765, 802, 35, 792, 18, 78, 594, 27, 617, 643, 879, 440, 652, 674, 808, 396, 598, 76, 7, 283, 912, 471, 551, 432, 91, 403, 257, 786, 313, 675, 84, 437, 721, 25, 704, 17, 825, 341, 706, 307, 505, 860, 788, 366, 301, 431, 
384, 231, 252, 555, 255, 578, 211, 818, 759, 523, 249, 230, 71, 906, 601, 370, 718, 351, 541, 400, 297, 640, 641, 768, 111, 841, 631, 780, 833, 399, 325, 676, 877, 939, 462, 157, 621, 803, 153, 44, 74, 406, 693, 857, 458, 256, 56, 189, 506, 789, 328, 330, 290, 757, 822, 751, 58, 128, 858, 176, 340, 762, 907, 913, 459, 372, 410, 613, 285, 269, 660, 714, 931, 379, 625, 567, 12, 205, 387, 821, 439, 421, 175, 828, 897, 588, 874, 63, 408, 476, 794, 629, 115, 901, 248, 350, 769, 511, 726, 244, 929, 743, 422, 742, 584, 933, 512, 164, 744, 943, 888, 19, 397, 870, 691, 559, 694, 597, 733, 722, 352, 586, 492, 708, 854, 201, 842, 487, 349, 542, 654, 820, 583, 150, 709, 110, 651, 232, 563, 222, 935, 710, 323, 585, 514, 683, 475, 797, 54, 538, 760, 306, 196, 740, 120, 254, 530, 38, 385, 304, 266, 904, 776, 447, 604, 23, 747, 572, 97, 889, 127, 914, 908, 752, 633, 199, 246, 615, 610, 732, 409, 314, 491, 837, 272, 883, 239, 501, 553, 295, 124, 182, 66, 298, 337, 461, 812, 482, 661, 427, 13, 667, 498, 600, 33, 659, 483, 554, 829, 312, 466, 535, 260, 727, 264, 673, 823, 95, 429, 450, 219, 892, 884, 209, 237, 103, 723, 10, 893, 49, 790, 540, 433, 725, 289, 845, 690, 102, 52, 678, 672, 636, 184, 531, 83, 389, 364, 77, 200, 682, 193, 223, 173, 194, 775, 684, 552, 599, 67, 873, 569, 152, 288, 216, 131, 376, 872, 632, 595, 618, 671, 344, 855, 87, 665, 515, 774, 183, 763, 363, 411, 469, 620, 140, 547, 881, 500, 117, 251, 207, 121, 766, 206, 444, 589, 909, 30, 730, 835, 520, 141, 114, 606, 688, 359, 779, 550, 280, 593, 146, 626, 856, 148, 310, 800, 504, 332, 536, 685, 348, 362, 900, 662, 386, 177, 936, 670, 496, 724, 848, 832, 532, 463, 692, 781, 85, 308, 61, 197, 271, 612, 220, 635, 275, 801],
    [898, 343, 213, 886, 281, 94, 377, 645, 925, 649, 697, 40, 454, 6, 677, 382, 1, 15, 521, 576, 915, 755, 863, 503, 497, 533, 899, 566, 777, 278, 813, 871, 477, 573, 782, 32, 481, 700, 36, 488, 132, 90, 191, 728, 154, 830, 910, 369, 380, 321, 834, 342, 291, 414, 293, 516, 927, 68, 539, 876, 99, 703, 795, 100, 430, 443, 784, 169, 526, 847, 423, 734, 564, 287, 796, 502, 875, 158, 642, 545, 112, 570, 324, 311, 394, 448, 261, 646, 136, 144, 811, 309, 748, 241, 799, 707, 53, 489, 712, 398, 546, 590, 50, 383, 815, 558, 839, 749, 96, 878, 698, 887, 702, 853, 123, 279, 928, 268, 616, 294, 843, 45, 543, 737, 156, 412, 299, 529, 20, 224, 919, 493, 880, 771, 192, 445, 791, 98, 159, 253, 202, 453, 365, 865, 106, 827, 135, 571, 418, 14, 225, 696, 119, 467, 741, 143, 705, 508, 302, 473, 75, 286, 305, 770, 274, 371, 228, 560, 46, 882, 395, 361, 932, 142, 911, 79, 319, 416, 465, 449, 850, 21, 162, 151, 614, 513, 277, 668, 442, 155, 108, 630, 243, 472, 47, 680, 869, 452, 245, 761, 861, 11, 575, 22, 890, 548, 160, 358, 137, 419, 938, 934, 750, 716, 686, 525, 318, 644, 60, 436, 72, 335, 867, 116, 729, 605, 118, 404, 809, 139, 738, 485, 354, 648, 434, 334, 426, 591, 345, 258, 701, 917, 647, 381, 315, 129, 320, 88, 806, 746, 413, 405, 849, 891, 57, 926, 109, 316, 580, 34, 435, 687, 518, 55, 664, 356, 699, 368, 816, 82, 524, 731, 916, 69, 86, 375, 817, 603, 327, 250, 918, 89, 212, 185, 456, 527, 31, 178, 756, 864, 238, 778, 51, 587, 941, 218, 165, 885, 851, 852, 149, 921, 126, 656, 519, 236, 4, 420, 104, 208, 805, 905, 187, 713, 510, 438, 562, 720, 824, 367, 717, 622, 592, 415, 544, 133, 317, 240, 374, 787, 417, 300, 844, 180, 735, 339, 627, 346, 793, 303, 894, 486, 623, 924, 666, 840, 388, 5, 64, 937, 425, 639, 39, 393, 296, 221, 24, 607, 507, 528, 3, 122, 758, 903, 765, 802, 35, 792, 18, 78, 594, 27, 617, 643, 879, 440, 652, 674, 808, 396, 598, 76, 7, 283, 912, 471, 551, 432, 91, 403, 257, 786, 313, 675, 84, 437, 721, 25, 704, 17, 825, 341, 706, 307, 505, 860, 788, 366, 301, 431, 
384, 231, 252, 555, 255, 578, 211, 818, 759, 523, 249, 230, 71, 906, 601, 370, 718, 351, 541, 400, 297, 640, 641, 768, 111, 841, 631, 780, 833, 399, 325, 676, 877, 939, 462, 157, 621, 803, 153, 44, 74, 406, 693, 857, 458, 256, 56, 189, 506, 789, 328, 330, 290, 757, 822, 751, 58, 128, 858, 176, 340, 762, 907, 913, 459, 372, 410, 613, 285, 269, 660, 714, 931, 379, 625, 567, 12, 205, 387, 821, 439, 421, 175, 828, 897, 588, 874, 63, 408, 476, 794, 629, 115, 901, 248, 350, 769, 511, 726, 244, 929, 743, 422, 742, 584, 933, 512, 164, 744, 943, 888, 19, 397, 870, 691, 559, 694, 597, 733, 722, 352, 586, 492, 708, 854, 201, 842, 487, 349, 542, 654, 820, 583, 150, 709, 110, 651, 232, 563, 222, 935, 710, 323, 585, 514, 683, 475, 797, 54, 538, 760, 306, 196, 740, 120, 254, 347, 329, 41, 754, 736, 37, 28, 517, 247, 490, 29, 391, 807, 715, 619, 166, 468, 895, 695, 93, 446, 259, 390, 624, 810, 210, 198, 81, 739, 611, 681, 804, 235, 923, 161, 16, 188, 940, 634, 522, 655, 859, 170, 608, 574, 134, 105, 186, 669, 534, 930, 836, 229, 441, 663, 556, 190, 424, 568, 557, 73, 125, 653, 270, 80, 637, 826, 596, 814, 204, 62, 838, 284, 147, 745, 355, 638, 561, 145, 203, 407, 455, 326, 868, 628, 549, 226, 42, 464, 234, 942, 107, 357, 495, 581, 657, 679, 478, 360, 171, 48, 470, 292, 457, 474, 378, 577, 276, 764, 373, 711, 101, 138, 215, 783, 2, 273, 773, 499, 753, 609, 172, 163, 480, 798, 113, 242, 494, 831, 392, 333, 451, 920, 689, 168, 719, 922, 265, 565, 167, 866, 785, 819, 262, 401, 479, 267, 322, 402, 195, 263, 896, 233, 179, 767, 338, 331, 130, 582, 214, 650, 484, 428, 59, 579, 227, 43, 181, 70, 772, 174, 460, 902, 509, 8, 336, 9, 846, 92, 658, 537, 217, 602, 282, 862, 26, 353, 65]
]
# Held-out users per fold: every id from 1 to 943 that is absent from that
# fold's training list. One entry is produced per training list, and each
# training list is turned into a set once so that the membership test is a
# constant-time lookup instead of a linear scan per id.
user_test = [[i for i in range(1, 944) if i not in trainIds] for trainIds in map(set, user_train)]

In [None]:
from sklearn.neural_network import MLPRegressor

In [None]:
# mean_absolute_error was called below without being imported in any of the
# visible cells; import it next to its only use.
from sklearn.metrics import mean_absolute_error

# The dataset does not depend on the layer width, so it is loaded once
# instead of once per width.
allMoives = getAllMovieMetadata("../ml-100k/u.item")
allUsers = getAllUserMetadata("../ml-100k/u.user")
allRatings = getAllRatings("../ml-100k/u.data")

# For every layer width (100, 200, 300, 400) train an MLP regressor with two
# hidden layers of that width on each precomputed fold, and append the
# per-fold MAE and the per-width summary to the file 'log2'.
for units in [100 * i for i in range(1, 5)]:

    allMae = []

    with open('log2', 'a') as fp:
        fp.write(str(units) + '\n')

    for i in range(len(user_train)):

        print("Fold %d of %d" % (i, 5))

        # Sets give constant-time membership tests in the rating loop below.
        # NOTE(review): the rating records carry 0-indexed user ids
        # (getAllRatings subtracts 1) while user_test is built from ids
        # 1..943 - confirm that the precomputed folds use the same indexing.
        train = set(user_train[i])
        test = set(user_test[i])

        print("created training and test rating sets ...")

        trainFeatureVectors = []
        trainLabels = []
        testX = []
        testY = []

        print("creating training feature vectors ...")

        for point in allRatings:

            currUser = allUsers[point[0]]
            currMovie = allMoives[point[1]]

            # Labels are rating - 1
            if point[0] in train:

                trainFeatureVectors.append(generateFeatureVector(currUser, currMovie))
                trainLabels.append(int(point[2]) - 1)

            elif point[0] in test:

                testX.append(generateFeatureVector(currUser, currMovie))
                testY.append(int(point[2]) - 1)

        testX = np.array(testX)
        testY = np.array(testY)

        trainX = np.array(trainFeatureVectors)
        trainY = np.array(trainLabels)

        reg = MLPRegressor(hidden_layer_sizes = (units, units))
        reg.fit(trainX, trainY)
        print('Fit over')

        # Round to the nearest label and clamp to the label range. The
        # labels are rating - 1, i.e. 0..4 for ratings 1..5, so the bounds
        # are 0 and 4; the previous bounds (negative -> 1, above 5 -> 5)
        # belonged to the unshifted rating scale.
        predicted = np.clip(np.round(reg.predict(testX)), 0, 4)
        mae = mean_absolute_error(testY, predicted)

        # Record every fold. This used to sit outside the fold loop, so only
        # the last fold's error was logged and averaged.
        with open('log2', 'a') as fp:
            fp.write(str(mae) + '\n')

        allMae.append(mae)

        print("error calculated, = %f" % (mae))

    print("all folds done, average mae = %f" % (sum(allMae) /len(allMae)))

    # NOTE(review): the logged value is the average MAE divided by 4,
    # presumably a normalisation by the rating range - confirm.
    with open('log2', 'a') as fp:
        fp.write(str((sum(allMae) /len(allMae)) / 4.) + '\n')


