In [1]:
import os
import sys
import random
import csv
import shutil
import string
import numpy as np
from pyoperant import utils, components, local, hwio

In [2]:
from pyoperant.behavior import two_alt_choice
import copy
import string

# 1. New 2AC child class
1. New trial type ("test").
2. New block design (a test trial cannot be the last trial of a VR run; how do we enforce that?).
3. How does correction work with test trials?
4. 

# 2. New reinforcement schedule  
1. Changes `consequate`.

2. 

In [None]:
class FixedRatioSchedule(BaseSchedule):
    """Fixed ratio (FR) reinforcement schedule.

    Keeps a running count of correct trials and reports that a trial should
    be consequated once that count reaches 'ratio'. An incorrect trial is
    always consequated and resets the running count to zero, so the count
    effectively tracks consecutive correct responses.

    Methods:
    consequate(trial) -- returns a boolean telling the caller whether the
        given trial should be consequated.
    """
    # NOTE(review): BaseSchedule is not imported in the visible cells;
    # presumably it comes from pyoperant.reinf -- confirm.

    def __init__(self, ratio=1):
        super(FixedRatioSchedule, self).__init__()
        # Ratios below 1 are clamped to 1.
        self.ratio = max(ratio, 1)
        self._update()

    def _update(self):
        """Start a new run: reset the threshold and the correct-trial count."""
        self.threshold = self.ratio
        self.cumulative_correct = 0

    def consequate(self, trial):
        """Return True when `trial` should be consequated.

        `trial` must expose a boolean `correct` attribute.
        """
        assert hasattr(trial, 'correct') and isinstance(trial.correct, bool)
        outcome = trial.correct

        if outcome == True:
            self.cumulative_correct += 1
            if self.cumulative_correct < self.threshold:
                # Run is not long enough yet: withhold the consequence.
                return False
            # Threshold reached: begin a fresh run and consequate.
            self._update()
            return True

        if outcome == False:
            # Errors always get consequated and break the current run.
            self.cumulative_correct = 0
            return True

        # Only reachable for non-boolean values when asserts are disabled.
        return False

    def __unicode__(self):
        return "FR%i" % self.ratio

class VariableRatioSchedule(FixedRatioSchedule):
    """Maintains logic for deciding whether to consequate trials.

    This class implements a variable ratio schedule, where a reward
    reinforcement is provided after a number of consecutive correct
    responses. On average, the number of consecutive responses necessary is
    the 'ratio'. After a reinforcement is provided, the number of consecutive
    correct trials needed for the next reinforcement is drawn uniformly from
    the integers 1 .. 2*ratio-1 (both ends included), e.g. a ratio of '3'
    will require consecutive correct trials of 1, 2, 3, 4, or 5, randomly.
    A ratio of 1 therefore always requires exactly one correct trial.
    Incorrect trials are always reinforced (behaviour inherited from
    FixedRatioSchedule.consequate).

    Methods:
    consequate(trial) -- returns a boolean value based on whether the trial
        should be consequated.
    """
    def __init__(self, ratio=1):
        super(VariableRatioSchedule, self).__init__(ratio=ratio)

    def _update(self):
        ''' reset the correct count and re-draw the threshold from [1, 2*ratio-1]'''
        self.cumulative_correct = 0
        # random.randint includes BOTH endpoints, so the upper bound has to be
        # 2*ratio-1; using 2*ratio would bias the mean to ratio + 0.5.
        self.threshold = random.randint(1, 2*self.ratio - 1)

    def __unicode__(self):
        return "VR%i" % self.ratio

In [None]:
class soundtexture_2AC(two_alt_choice.TwoAltChoiceExp):
    """
    Two-alternative choice experiment whose trial block is generated randomly.

    For every session a block of 100 trials is drawn. Each trial gets a side
    code (0 = left, 1 = right) and one stimulus picked at random from the
    matching "left_stims" / "right_stims" parameter table. The chosen stimuli
    and their conditions are then written into self.parameters ("stims" and
    the "default" block's "conditions" list).
    """

    def __init__(self, *args, **kwargs):

        super(soundtexture_2AC, self).__init__(*args, **kwargs)
        # Keep a pristine copy of the parameters so that session_post() can
        # restore them before the next block is generated.
        self.starting_params = copy.deepcopy(self.parameters)
        self.get_conditions()  # draw the side code (0/1) of every trial
        self.get_motifs()  # pick one stimulus per trial
        self.build_block()  # register stimuli and conditions in parameters

    def get_conditions(self):
        """
        Draws a random block of 100 trials, each coded 0 (left) or 1 (right).

        Sets self.trial_types, self.trials (the list of side codes) and
        self.trial_output (the "class" entry of
        parameters["category_conditions"] for each trial).
        """
        self.trial_types = np.matrix('0; 1')
        self.trials = [random.randrange(0, 2, 1) for _ in range(100)]
        self.trial_output = [self.parameters["category_conditions"][i]["class"]
                             for i in self.trials]

    def get_motifs(self):
        """
        Picks one stimulus for every trial in self.trials.

        A trial coded 1 draws from parameters["right_stims"], any other code
        draws from parameters["left_stims"]. Both tables are indexed by the
        string form of an integer in [0, current_available_motifs).
        The result is stored in self.motifs.
        """
        self.motifs = []
        molen = self.parameters["current_available_motifs"]
        left_stims = self.parameters["left_stims"]
        right_stims = self.parameters["right_stims"]

        for side in self.trials:
            # go right if 1, go left if 0
            pool = right_stims if side == 1 else left_stims
            self.motifs.append(pool[str(random.randrange(0, molen, 1))])

    def build_block(self):
        """
        Adds stims and trial classes to parameters.

        For each (stimulus, side) pair the stimulus path is registered under
        parameters["stims"] and a condition dict with "class" and "stim_name"
        is appended to the "default" block's condition list.
        """
        for stim_name, side in zip(self.motifs, self.trials):
            # Side code 1 was drawn from right_stims in get_motifs(), so it
            # must be labelled "R" (it used to be labelled "L", which was the
            # opposite of the stimulus pool actually used).
            # NOTE(review): confirm against the experiment's parameter file.
            condition = {
                "class": "R" if side == 1 else "L",
                "stim_name": stim_name,
            }
            self.parameters["stims"][stim_name] = self.parameters["stim_path"]+"/"+stim_name
            self.parameters["block_design"]["blocks"]["default"]["conditions"].append(condition)

    def session_post(self):
        """
        Closes out the session and prepares a fresh random block.

        Restores the parameters saved in __init__, regenerates the trial
        block, logs the end of the session and clears self.trial_q.
        Returns None.
        """
        # Start from the original parameters so conditions do not accumulate
        # across sessions.
        self.parameters = copy.deepcopy(self.starting_params)
        # NOTE: the clear_song_folder() call is currently disabled here.
        self.get_conditions()  # draw the side code (0/1) of every trial
        self.get_motifs()  # pick one stimulus per trial
        self.build_block()  # register stimuli and conditions in parameters
        self.log.info('ending session')
        self.trial_q = None
        return None

    def clear_song_folder(self):
        """
        Deletes every regular file in <stim_path>/Generated_Songs/.

        Sub-directories are left untouched. A failure to delete one file is
        printed and does not stop the remaining files from being processed.
        """
        folder = self.parameters["stim_path"] + "/Generated_Songs/"
        for the_file in os.listdir(folder):
            file_path = os.path.join(folder, the_file)
            if not os.path.isfile(file_path):
                continue
            try:
                os.unlink(file_path)
            except OSError as e:
                # Best effort: report the problem and keep going.
                print(e)
